1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{self, Database, ProjectId, RoomId, User, UserId},
6 executor::Executor,
7 AppState, Result,
8};
9use anyhow::anyhow;
10use async_tungstenite::tungstenite::{
11 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
12};
13use axum::{
14 body::Body,
15 extract::{
16 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
17 ConnectInfo, WebSocketUpgrade,
18 },
19 headers::{Header, HeaderName},
20 http::StatusCode,
21 middleware,
22 response::IntoResponse,
23 routing::get,
24 Extension, Router, TypedHeader,
25};
26use collections::{HashMap, HashSet};
27pub use connection_pool::ConnectionPool;
28use futures::{
29 channel::oneshot,
30 future::{self, BoxFuture},
31 stream::FuturesUnordered,
32 FutureExt, SinkExt, StreamExt, TryStreamExt,
33};
34use lazy_static::lazy_static;
35use prometheus::{register_int_gauge, IntGauge};
36use rpc::{
37 proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
38 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
39};
40use serde::{Serialize, Serializer};
41use std::{
42 any::TypeId,
43 fmt,
44 future::Future,
45 marker::PhantomData,
46 mem,
47 net::SocketAddr,
48 ops::{Deref, DerefMut},
49 rc::Rc,
50 sync::{
51 atomic::{AtomicBool, Ordering::SeqCst},
52 Arc,
53 },
54 time::Duration,
55};
56use tokio::sync::{Mutex, MutexGuard};
57use tower::ServiceBuilder;
58use tracing::{info_span, instrument, Instrument};
59
/// How long `sign_out` waits after a connection is lost before removing the session from its
/// room; this reuses the RPC layer's `rpc::RECEIVE_TIMEOUT`.
pub const RECONNECTION_TIMEOUT: Duration = rpc::RECEIVE_TIMEOUT;

lazy_static! {
    /// Gauge of non-admin connections; refreshed by `handle_metrics` on every scrape.
    /// The `unwrap` panics if registration with the default registry fails.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    /// Gauge of shared projects excluding admins; refreshed by `handle_metrics` on every scrape.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
71
72type MessageHandler =
73 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
74
75struct Response<R> {
76 peer: Arc<Peer>,
77 receipt: Receipt<R>,
78 responded: Arc<AtomicBool>,
79}
80
81impl<R: RequestMessage> Response<R> {
82 fn send(self, payload: R::Response) -> Result<()> {
83 self.responded.store(true, SeqCst);
84 self.peer.respond(self.receipt, payload)?;
85 Ok(())
86 }
87}
88
/// Per-connection state; `handle_connection` builds one and hands a clone to every handler.
#[derive(Clone)]
struct Session {
    user_id: UserId,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex created per connection; acquire via `Session::db`.
    db: Arc<Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// The same pool `Server` holds; acquire via `Session::connection_pool`.
    connection_pool: Arc<Mutex<ConnectionPool>>,
    /// Cloned from `AppState`; `None` when no LiveKit client is configured there.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
}
98
impl Session {
    /// Locks and returns this session's database handle.
    ///
    /// Test builds yield to the scheduler before and after taking the lock — presumably to
    /// shake out interleavings between concurrently running handlers.
    async fn db(&self) -> MutexGuard<DbHandle> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.db.lock().await;
        #[cfg(test)]
        tokio::task::yield_now().await;
        guard
    }

    /// Locks the shared connection pool and wraps the guard in a `ConnectionPoolGuard`.
    ///
    /// The wrapper is `!Send` (via `PhantomData<Rc<()>>`), so it cannot be held across an
    /// `.await` inside the `Send` futures handlers must return. Test builds yield around the
    /// lock acquisition, as in `db`.
    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.connection_pool.lock().await;
        #[cfg(test)]
        tokio::task::yield_now().await;
        ConnectionPoolGuard {
            guard,
            _not_send: PhantomData,
        }
    }
}
121
122impl fmt::Debug for Session {
123 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124 f.debug_struct("Session")
125 .field("user_id", &self.user_id)
126 .field("connection_id", &self.connection_id)
127 .finish()
128 }
129}
130
131struct DbHandle(Arc<Database>);
132
133impl Deref for DbHandle {
134 type Target = Database;
135
136 fn deref(&self) -> &Self::Target {
137 self.0.as_ref()
138 }
139}
140
/// The collaboration RPC server: the `Peer`, the connection pool shared with every session,
/// the application state, and the per-message-type handler table.
pub struct Server {
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Filled in by `Server::new`; `add_handler` panics if a type is registered twice.
    handlers: HashMap<TypeId, MessageHandler>,
}
147
/// Lock guard over the `ConnectionPool`.
///
/// `PhantomData<Rc<()>>` makes it `!Send`, so it cannot be kept alive across an `.await` in a
/// `Send` future. Its `Drop` impl checks the pool's invariants in test builds.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable view of the server; holds the connection-pool lock for as long as it lives.
///
/// Serde's derive emits fields in declaration order, so field order is part of the output.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through `Deref`, so the pool itself (not the guard) is written out.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
159
160pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
161where
162 S: Serializer,
163 T: Deref<Target = U>,
164 U: Serialize,
165{
166 Serialize::serialize(value.deref(), serializer)
167}
168
169impl Server {
170 pub fn new(app_state: Arc<AppState>) -> Arc<Self> {
171 let mut server = Self {
172 peer: Peer::new(),
173 app_state,
174 connection_pool: Default::default(),
175 handlers: Default::default(),
176 };
177
178 server
179 .add_request_handler(ping)
180 .add_request_handler(create_room)
181 .add_request_handler(join_room)
182 .add_message_handler(leave_room)
183 .add_request_handler(call)
184 .add_request_handler(cancel_call)
185 .add_message_handler(decline_call)
186 .add_request_handler(update_participant_location)
187 .add_request_handler(share_project)
188 .add_message_handler(unshare_project)
189 .add_request_handler(join_project)
190 .add_message_handler(leave_project)
191 .add_request_handler(update_project)
192 .add_request_handler(update_worktree)
193 .add_message_handler(start_language_server)
194 .add_message_handler(update_language_server)
195 .add_message_handler(update_diagnostic_summary)
196 .add_request_handler(forward_project_request::<proto::GetHover>)
197 .add_request_handler(forward_project_request::<proto::GetDefinition>)
198 .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
199 .add_request_handler(forward_project_request::<proto::GetReferences>)
200 .add_request_handler(forward_project_request::<proto::SearchProject>)
201 .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
202 .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
203 .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
204 .add_request_handler(forward_project_request::<proto::OpenBufferById>)
205 .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
206 .add_request_handler(forward_project_request::<proto::GetCompletions>)
207 .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
208 .add_request_handler(forward_project_request::<proto::GetCodeActions>)
209 .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
210 .add_request_handler(forward_project_request::<proto::PrepareRename>)
211 .add_request_handler(forward_project_request::<proto::PerformRename>)
212 .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
213 .add_request_handler(forward_project_request::<proto::FormatBuffers>)
214 .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
215 .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
216 .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
217 .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
218 .add_message_handler(create_buffer_for_peer)
219 .add_request_handler(update_buffer)
220 .add_message_handler(update_buffer_file)
221 .add_message_handler(buffer_reloaded)
222 .add_message_handler(buffer_saved)
223 .add_request_handler(save_buffer)
224 .add_request_handler(get_users)
225 .add_request_handler(fuzzy_search_users)
226 .add_request_handler(request_contact)
227 .add_request_handler(remove_contact)
228 .add_request_handler(respond_to_contact_request)
229 .add_request_handler(follow)
230 .add_message_handler(unfollow)
231 .add_message_handler(update_followers)
232 .add_message_handler(update_diff_base)
233 .add_request_handler(get_private_user_info);
234
235 Arc::new(server)
236 }
237
238 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
239 where
240 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
241 Fut: 'static + Send + Future<Output = Result<()>>,
242 M: EnvelopedMessage,
243 {
244 let prev_handler = self.handlers.insert(
245 TypeId::of::<M>(),
246 Box::new(move |envelope, session| {
247 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
248 let span = info_span!(
249 "handle message",
250 payload_type = envelope.payload_type_name()
251 );
252 span.in_scope(|| {
253 tracing::info!(
254 payload_type = envelope.payload_type_name(),
255 "message received"
256 );
257 });
258 let future = (handler)(*envelope, session);
259 async move {
260 if let Err(error) = future.await {
261 tracing::error!(%error, "error handling message");
262 }
263 }
264 .instrument(span)
265 .boxed()
266 }),
267 );
268 if prev_handler.is_some() {
269 panic!("registered a handler for the same message twice");
270 }
271 self
272 }
273
274 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
275 where
276 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
277 Fut: 'static + Send + Future<Output = Result<()>>,
278 M: EnvelopedMessage,
279 {
280 self.add_handler(move |envelope, session| handler(envelope.payload, session));
281 self
282 }
283
284 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
285 where
286 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
287 Fut: Send + Future<Output = Result<()>>,
288 M: RequestMessage,
289 {
290 let handler = Arc::new(handler);
291 self.add_handler(move |envelope, session| {
292 let receipt = envelope.receipt();
293 let handler = handler.clone();
294 async move {
295 let peer = session.peer.clone();
296 let responded = Arc::new(AtomicBool::default());
297 let response = Response {
298 peer: peer.clone(),
299 responded: responded.clone(),
300 receipt,
301 };
302 match (handler)(envelope.payload, response, session).await {
303 Ok(()) => {
304 if responded.load(std::sync::atomic::Ordering::SeqCst) {
305 Ok(())
306 } else {
307 Err(anyhow!("handler did not send a response"))?
308 }
309 }
310 Err(error) => {
311 peer.respond_with_error(
312 receipt,
313 proto::Error {
314 message: error.to_string(),
315 },
316 )?;
317 Err(error)
318 }
319 }
320 }
321 })
322 }
323
324 pub fn handle_connection(
325 self: &Arc<Self>,
326 connection: Connection,
327 address: String,
328 user: User,
329 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
330 executor: Executor,
331 ) -> impl Future<Output = Result<()>> {
332 let this = self.clone();
333 let user_id = user.id;
334 let login = user.github_login;
335 let span = info_span!("handle connection", %user_id, %login, %address);
336 async move {
337 let (connection_id, handle_io, mut incoming_rx) = this
338 .peer
339 .add_connection(connection, {
340 let executor = executor.clone();
341 move |duration| executor.sleep(duration)
342 });
343
344 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
345 this.peer.send(connection_id, proto::Hello { peer_id: connection_id.0 })?;
346 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
347
348 if let Some(send_connection_id) = send_connection_id.take() {
349 let _ = send_connection_id.send(connection_id);
350 }
351
352 if !user.connected_once {
353 this.peer.send(connection_id, proto::ShowContacts {})?;
354 this.app_state.db.set_user_connected_once(user_id, true).await?;
355 }
356
357 let (contacts, invite_code) = future::try_join(
358 this.app_state.db.get_contacts(user_id),
359 this.app_state.db.get_invite_code_for_user(user_id)
360 ).await?;
361
362 {
363 let mut pool = this.connection_pool.lock().await;
364 pool.add_connection(connection_id, user_id, user.admin);
365 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
366
367 if let Some((code, count)) = invite_code {
368 this.peer.send(connection_id, proto::UpdateInviteInfo {
369 url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
370 count: count as u32,
371 })?;
372 }
373 }
374
375 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
376 this.peer.send(connection_id, incoming_call)?;
377 }
378
379 let session = Session {
380 user_id,
381 connection_id,
382 db: Arc::new(Mutex::new(DbHandle(this.app_state.db.clone()))),
383 peer: this.peer.clone(),
384 connection_pool: this.connection_pool.clone(),
385 live_kit_client: this.app_state.live_kit_client.clone()
386 };
387 update_user_contacts(user_id, &session).await?;
388
389 let handle_io = handle_io.fuse();
390 futures::pin_mut!(handle_io);
391
392 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
393 // This prevents deadlocks when e.g., client A performs a request to client B and
394 // client B performs a request to client A. If both clients stop processing further
395 // messages until their respective request completes, they won't have a chance to
396 // respond to the other client's request and cause a deadlock.
397 //
398 // This arrangement ensures we will attempt to process earlier messages first, but fall
399 // back to processing messages arrived later in the spirit of making progress.
400 let mut foreground_message_handlers = FuturesUnordered::new();
401 loop {
402 let next_message = incoming_rx.next().fuse();
403 futures::pin_mut!(next_message);
404 futures::select_biased! {
405 result = handle_io => {
406 if let Err(error) = result {
407 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
408 }
409 break;
410 }
411 _ = foreground_message_handlers.next() => {}
412 message = next_message => {
413 if let Some(message) = message {
414 let type_name = message.payload_type_name();
415 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
416 let span_enter = span.enter();
417 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
418 let is_background = message.is_background();
419 let handle_message = (handler)(message, session.clone());
420 drop(span_enter);
421
422 let handle_message = handle_message.instrument(span);
423 if is_background {
424 executor.spawn_detached(handle_message);
425 } else {
426 foreground_message_handlers.push(handle_message);
427 }
428 } else {
429 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
430 }
431 } else {
432 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
433 break;
434 }
435 }
436 }
437 }
438
439 drop(foreground_message_handlers);
440 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
441 if let Err(error) = sign_out(session, executor).await {
442 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
443 }
444
445 Ok(())
446 }.instrument(span)
447 }
448
449 pub async fn invite_code_redeemed(
450 self: &Arc<Self>,
451 inviter_id: UserId,
452 invitee_id: UserId,
453 ) -> Result<()> {
454 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
455 if let Some(code) = &user.invite_code {
456 let pool = self.connection_pool.lock().await;
457 let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
458 for connection_id in pool.user_connection_ids(inviter_id) {
459 self.peer.send(
460 connection_id,
461 proto::UpdateContacts {
462 contacts: vec![invitee_contact.clone()],
463 ..Default::default()
464 },
465 )?;
466 self.peer.send(
467 connection_id,
468 proto::UpdateInviteInfo {
469 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
470 count: user.invite_count as u32,
471 },
472 )?;
473 }
474 }
475 }
476 Ok(())
477 }
478
479 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
480 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
481 if let Some(invite_code) = &user.invite_code {
482 let pool = self.connection_pool.lock().await;
483 for connection_id in pool.user_connection_ids(user_id) {
484 self.peer.send(
485 connection_id,
486 proto::UpdateInviteInfo {
487 url: format!(
488 "{}{}",
489 self.app_state.config.invite_link_prefix, invite_code
490 ),
491 count: user.invite_count as u32,
492 },
493 )?;
494 }
495 }
496 }
497 Ok(())
498 }
499
500 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
501 ServerSnapshot {
502 connection_pool: ConnectionPoolGuard {
503 guard: self.connection_pool.lock().await,
504 _not_send: PhantomData,
505 },
506 peer: &self.peer,
507 }
508 }
509}
510
511impl<'a> Deref for ConnectionPoolGuard<'a> {
512 type Target = ConnectionPool;
513
514 fn deref(&self) -> &Self::Target {
515 &*self.guard
516 }
517}
518
519impl<'a> DerefMut for ConnectionPoolGuard<'a> {
520 fn deref_mut(&mut self) -> &mut Self::Target {
521 &mut *self.guard
522 }
523}
524
525impl<'a> Drop for ConnectionPoolGuard<'a> {
526 fn drop(&mut self) {
527 #[cfg(test)]
528 self.check_invariants();
529 }
530}
531
532fn broadcast<F>(
533 sender_id: ConnectionId,
534 receiver_ids: impl IntoIterator<Item = ConnectionId>,
535 mut f: F,
536) where
537 F: FnMut(ConnectionId) -> anyhow::Result<()>,
538{
539 for receiver_id in receiver_ids {
540 if receiver_id != sender_id {
541 f(receiver_id).trace_err();
542 }
543 }
544}
545
lazy_static! {
    /// Name of the request header carrying the client's RPC protocol version.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}

/// Typed `x-zed-protocol-version` header. `handle_websocket_request` compares it against
/// `rpc::PROTOCOL_VERSION` before upgrading the connection.
pub struct ProtocolVersion(u32);
551
552impl Header for ProtocolVersion {
553 fn name() -> &'static HeaderName {
554 &ZED_PROTOCOL_VERSION
555 }
556
557 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
558 where
559 Self: Sized,
560 I: Iterator<Item = &'i axum::http::HeaderValue>,
561 {
562 let version = values
563 .next()
564 .ok_or_else(axum::headers::Error::invalid)?
565 .to_str()
566 .map_err(|_| axum::headers::Error::invalid())?
567 .parse()
568 .map_err(|_| axum::headers::Error::invalid())?;
569 Ok(Self(version))
570 }
571
572 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
573 values.extend([self.0.to_string().parse().unwrap()]);
574 }
575}
576
/// Builds the router exposing `/rpc` (websocket) and `/metrics`.
///
/// Axum's `Router::layer` wraps only the routes added before it. So `/rpc` gets the
/// `AppState` extension and the `auth::validate_header` middleware, while `/metrics`, added
/// afterwards, does not. The final `Extension(server)` layer covers both routes.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
588
589pub async fn handle_websocket_request(
590 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
591 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
592 Extension(server): Extension<Arc<Server>>,
593 Extension(user): Extension<User>,
594 ws: WebSocketUpgrade,
595) -> axum::response::Response {
596 if protocol_version != rpc::PROTOCOL_VERSION {
597 return (
598 StatusCode::UPGRADE_REQUIRED,
599 "client must be upgraded".to_string(),
600 )
601 .into_response();
602 }
603 let socket_address = socket_address.to_string();
604 ws.on_upgrade(move |socket| {
605 use util::ResultExt;
606 let socket = socket
607 .map_ok(to_tungstenite_message)
608 .err_into()
609 .with(|message| async move { Ok(to_axum_message(message)) });
610 let connection = Connection::new(Box::pin(socket));
611 async move {
612 server
613 .handle_connection(connection, socket_address, user, None, Executor::Production)
614 .await
615 .log_err();
616 }
617 })
618}
619
620pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
621 let connections = server
622 .connection_pool
623 .lock()
624 .await
625 .connections()
626 .filter(|connection| !connection.admin)
627 .count();
628
629 METRIC_CONNECTIONS.set(connections as _);
630
631 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
632 METRIC_SHARED_PROJECTS.set(shared_projects as _);
633
634 let encoder = prometheus::TextEncoder::new();
635 let metric_families = prometheus::gather();
636 let encoded_metrics = encoder
637 .encode_to_string(&metric_families)
638 .map_err(|err| anyhow!("{}", err))?;
639 Ok(encoded_metrics)
640}
641
642#[instrument(err, skip(executor))]
643async fn sign_out(session: Session, executor: Executor) -> Result<()> {
644 session.peer.disconnect(session.connection_id);
645 session
646 .connection_pool()
647 .await
648 .remove_connection(session.connection_id)?;
649
650 if let Ok(mut left_projects) = session
651 .db()
652 .await
653 .connection_lost(session.connection_id)
654 .await
655 {
656 for left_project in mem::take(&mut *left_projects) {
657 project_left(&left_project, &session);
658 }
659 }
660
661 executor.sleep(RECONNECTION_TIMEOUT).await;
662 leave_room_for_session(&session).await.trace_err();
663
664 if !session
665 .connection_pool()
666 .await
667 .is_user_online(session.user_id)
668 {
669 let db = session.db().await;
670 if let Some(room) = db.decline_call(None, session.user_id).await.trace_err() {
671 room_updated(&room, &session);
672 }
673 }
674 update_user_contacts(session.user_id, &session).await?;
675
676 Ok(())
677}
678
679async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
680 response.send(proto::Ack {})?;
681 Ok(())
682}
683
684async fn create_room(
685 _request: proto::CreateRoom,
686 response: Response<proto::CreateRoom>,
687 session: Session,
688) -> Result<()> {
689 let live_kit_room = nanoid::nanoid!(30);
690 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
691 if let Some(_) = live_kit
692 .create_room(live_kit_room.clone())
693 .await
694 .trace_err()
695 {
696 if let Some(token) = live_kit
697 .room_token(&live_kit_room, &session.connection_id.to_string())
698 .trace_err()
699 {
700 Some(proto::LiveKitConnectionInfo {
701 server_url: live_kit.url().into(),
702 token,
703 })
704 } else {
705 None
706 }
707 } else {
708 None
709 }
710 } else {
711 None
712 };
713
714 {
715 let room = session
716 .db()
717 .await
718 .create_room(session.user_id, session.connection_id, &live_kit_room)
719 .await?;
720
721 response.send(proto::CreateRoomResponse {
722 room: Some(room.clone()),
723 live_kit_connection_info,
724 })?;
725 }
726
727 update_user_contacts(session.user_id, &session).await?;
728 Ok(())
729}
730
731async fn join_room(
732 request: proto::JoinRoom,
733 response: Response<proto::JoinRoom>,
734 session: Session,
735) -> Result<()> {
736 let room = {
737 let room = session
738 .db()
739 .await
740 .join_room(
741 RoomId::from_proto(request.id),
742 session.user_id,
743 session.connection_id,
744 )
745 .await?;
746 room_updated(&room, &session);
747 room.clone()
748 };
749
750 for connection_id in session
751 .connection_pool()
752 .await
753 .user_connection_ids(session.user_id)
754 {
755 session
756 .peer
757 .send(connection_id, proto::CallCanceled {})
758 .trace_err();
759 }
760
761 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
762 if let Some(token) = live_kit
763 .room_token(&room.live_kit_room, &session.connection_id.to_string())
764 .trace_err()
765 {
766 Some(proto::LiveKitConnectionInfo {
767 server_url: live_kit.url().into(),
768 token,
769 })
770 } else {
771 None
772 }
773 } else {
774 None
775 };
776
777 response.send(proto::JoinRoomResponse {
778 room: Some(room),
779 live_kit_connection_info,
780 })?;
781
782 update_user_contacts(session.user_id, &session).await?;
783 Ok(())
784}
785
786async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
787 leave_room_for_session(&session).await
788}
789
790async fn call(
791 request: proto::Call,
792 response: Response<proto::Call>,
793 session: Session,
794) -> Result<()> {
795 let room_id = RoomId::from_proto(request.room_id);
796 let calling_user_id = session.user_id;
797 let calling_connection_id = session.connection_id;
798 let called_user_id = UserId::from_proto(request.called_user_id);
799 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
800 if !session
801 .db()
802 .await
803 .has_contact(calling_user_id, called_user_id)
804 .await?
805 {
806 return Err(anyhow!("cannot call a user who isn't a contact"))?;
807 }
808
809 let incoming_call = {
810 let (room, incoming_call) = &mut *session
811 .db()
812 .await
813 .call(
814 room_id,
815 calling_user_id,
816 calling_connection_id,
817 called_user_id,
818 initial_project_id,
819 )
820 .await?;
821 room_updated(&room, &session);
822 mem::take(incoming_call)
823 };
824 update_user_contacts(called_user_id, &session).await?;
825
826 let mut calls = session
827 .connection_pool()
828 .await
829 .user_connection_ids(called_user_id)
830 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
831 .collect::<FuturesUnordered<_>>();
832
833 while let Some(call_response) = calls.next().await {
834 match call_response.as_ref() {
835 Ok(_) => {
836 response.send(proto::Ack {})?;
837 return Ok(());
838 }
839 Err(_) => {
840 call_response.trace_err();
841 }
842 }
843 }
844
845 {
846 let room = session
847 .db()
848 .await
849 .call_failed(room_id, called_user_id)
850 .await?;
851 room_updated(&room, &session);
852 }
853 update_user_contacts(called_user_id, &session).await?;
854
855 Err(anyhow!("failed to ring user"))?
856}
857
858async fn cancel_call(
859 request: proto::CancelCall,
860 response: Response<proto::CancelCall>,
861 session: Session,
862) -> Result<()> {
863 let called_user_id = UserId::from_proto(request.called_user_id);
864 let room_id = RoomId::from_proto(request.room_id);
865 {
866 let room = session
867 .db()
868 .await
869 .cancel_call(Some(room_id), session.connection_id, called_user_id)
870 .await?;
871 room_updated(&room, &session);
872 }
873
874 for connection_id in session
875 .connection_pool()
876 .await
877 .user_connection_ids(called_user_id)
878 {
879 session
880 .peer
881 .send(connection_id, proto::CallCanceled {})
882 .trace_err();
883 }
884 response.send(proto::Ack {})?;
885
886 update_user_contacts(called_user_id, &session).await?;
887 Ok(())
888}
889
890async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
891 let room_id = RoomId::from_proto(message.room_id);
892 {
893 let room = session
894 .db()
895 .await
896 .decline_call(Some(room_id), session.user_id)
897 .await?;
898 room_updated(&room, &session);
899 }
900
901 for connection_id in session
902 .connection_pool()
903 .await
904 .user_connection_ids(session.user_id)
905 {
906 session
907 .peer
908 .send(connection_id, proto::CallCanceled {})
909 .trace_err();
910 }
911 update_user_contacts(session.user_id, &session).await?;
912 Ok(())
913}
914
915async fn update_participant_location(
916 request: proto::UpdateParticipantLocation,
917 response: Response<proto::UpdateParticipantLocation>,
918 session: Session,
919) -> Result<()> {
920 let room_id = RoomId::from_proto(request.room_id);
921 let location = request
922 .location
923 .ok_or_else(|| anyhow!("invalid location"))?;
924 let room = session
925 .db()
926 .await
927 .update_room_participant_location(room_id, session.connection_id, location)
928 .await?;
929 room_updated(&room, &session);
930 response.send(proto::Ack {})?;
931 Ok(())
932}
933
/// Shares a project from this connection into room `request.room_id` with the given
/// worktrees, replies with the new project id, then broadcasts the room update.
///
/// The `&*` pattern extends the lifetime of the value returned by `Database::share_project`
/// to the end of the function, so it is still alive when `room_updated` runs. The session's
/// database lock (a receiver temporary) is released at the end of the `let` statement.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: Session,
) -> Result<()> {
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
        )
        .await?;
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    room_updated(&room, &session);

    Ok(())
}
955
/// Unshares `message.project_id` on behalf of this connection.
///
/// Forwards the `UnshareProject` message to the connection ids returned by the database
/// (other than the sender), then broadcasts the room update. The `&*` pattern keeps the value
/// returned by `Database::unshare_project` alive until the end of the function.
async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let (room, guest_connection_ids) = &*session
        .db()
        .await
        .unshare_project(project_id, session.connection_id)
        .await?;

    broadcast(
        session.connection_id,
        guest_connection_ids.iter().copied(),
        |conn_id| session.peer.send(conn_id, message.clone()),
    );
    room_updated(&room, &session);

    Ok(())
}
974
975async fn join_project(
976 request: proto::JoinProject,
977 response: Response<proto::JoinProject>,
978 session: Session,
979) -> Result<()> {
980 let project_id = ProjectId::from_proto(request.project_id);
981 let guest_user_id = session.user_id;
982
983 tracing::info!(%project_id, "join project");
984
985 let (project, replica_id) = &mut *session
986 .db()
987 .await
988 .join_project(project_id, session.connection_id)
989 .await?;
990
991 let collaborators = project
992 .collaborators
993 .iter()
994 .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32)
995 .map(|collaborator| proto::Collaborator {
996 peer_id: collaborator.connection_id as u32,
997 replica_id: collaborator.replica_id.0 as u32,
998 user_id: collaborator.user_id.to_proto(),
999 })
1000 .collect::<Vec<_>>();
1001 let worktrees = project
1002 .worktrees
1003 .iter()
1004 .map(|(id, worktree)| proto::WorktreeMetadata {
1005 id: *id,
1006 root_name: worktree.root_name.clone(),
1007 visible: worktree.visible,
1008 abs_path: worktree.abs_path.clone(),
1009 })
1010 .collect::<Vec<_>>();
1011
1012 for collaborator in &collaborators {
1013 session
1014 .peer
1015 .send(
1016 ConnectionId(collaborator.peer_id),
1017 proto::AddProjectCollaborator {
1018 project_id: project_id.to_proto(),
1019 collaborator: Some(proto::Collaborator {
1020 peer_id: session.connection_id.0,
1021 replica_id: replica_id.0 as u32,
1022 user_id: guest_user_id.to_proto(),
1023 }),
1024 },
1025 )
1026 .trace_err();
1027 }
1028
1029 // First, we send the metadata associated with each worktree.
1030 response.send(proto::JoinProjectResponse {
1031 worktrees: worktrees.clone(),
1032 replica_id: replica_id.0 as u32,
1033 collaborators: collaborators.clone(),
1034 language_servers: project.language_servers.clone(),
1035 })?;
1036
1037 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1038 #[cfg(any(test, feature = "test-support"))]
1039 const MAX_CHUNK_SIZE: usize = 2;
1040 #[cfg(not(any(test, feature = "test-support")))]
1041 const MAX_CHUNK_SIZE: usize = 256;
1042
1043 // Stream this worktree's entries.
1044 let message = proto::UpdateWorktree {
1045 project_id: project_id.to_proto(),
1046 worktree_id,
1047 abs_path: worktree.abs_path.clone(),
1048 root_name: worktree.root_name,
1049 updated_entries: worktree.entries,
1050 removed_entries: Default::default(),
1051 scan_id: worktree.scan_id,
1052 is_last_update: worktree.is_complete,
1053 };
1054 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1055 session.peer.send(session.connection_id, update.clone())?;
1056 }
1057
1058 // Stream this worktree's diagnostics.
1059 for summary in worktree.diagnostic_summaries {
1060 session.peer.send(
1061 session.connection_id,
1062 proto::UpdateDiagnosticSummary {
1063 project_id: project_id.to_proto(),
1064 worktree_id: worktree.id,
1065 summary: Some(summary),
1066 },
1067 )?;
1068 }
1069 }
1070
1071 for language_server in &project.language_servers {
1072 session.peer.send(
1073 session.connection_id,
1074 proto::UpdateLanguageServer {
1075 project_id: project_id.to_proto(),
1076 language_server_id: language_server.id,
1077 variant: Some(
1078 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1079 proto::LspDiskBasedDiagnosticsUpdated {},
1080 ),
1081 ),
1082 },
1083 )?;
1084 }
1085
1086 Ok(())
1087}
1088
1089async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1090 let sender_id = session.connection_id;
1091 let project_id = ProjectId::from_proto(request.project_id);
1092
1093 let project = session
1094 .db()
1095 .await
1096 .leave_project(project_id, sender_id)
1097 .await?;
1098 tracing::info!(
1099 %project_id,
1100 host_user_id = %project.host_user_id,
1101 host_connection_id = %project.host_connection_id,
1102 "leave project"
1103 );
1104 project_left(&project, &session);
1105
1106 Ok(())
1107}
1108
/// Updates the worktrees of project `request.project_id` on behalf of this connection.
///
/// Forwards the request to the connection ids returned by the database (other than the
/// sender), broadcasts the room update, then acknowledges. The `&*` pattern keeps the value
/// returned by `Database::update_project` alive until the end of the function.
async fn update_project(
    request: proto::UpdateProject,
    response: Response<proto::UpdateProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let (room, guest_connection_ids) = &*session
        .db()
        .await
        .update_project(project_id, session.connection_id, &request.worktrees)
        .await?;
    broadcast(
        session.connection_id,
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    room_updated(&room, &session);
    response.send(proto::Ack {})?;

    Ok(())
}
1134
1135async fn update_worktree(
1136 request: proto::UpdateWorktree,
1137 response: Response<proto::UpdateWorktree>,
1138 session: Session,
1139) -> Result<()> {
1140 let guest_connection_ids = session
1141 .db()
1142 .await
1143 .update_worktree(&request, session.connection_id)
1144 .await?;
1145
1146 broadcast(
1147 session.connection_id,
1148 guest_connection_ids.iter().copied(),
1149 |connection_id| {
1150 session
1151 .peer
1152 .forward_send(session.connection_id, connection_id, request.clone())
1153 },
1154 );
1155 response.send(proto::Ack {})?;
1156 Ok(())
1157}
1158
1159async fn update_diagnostic_summary(
1160 message: proto::UpdateDiagnosticSummary,
1161 session: Session,
1162) -> Result<()> {
1163 let guest_connection_ids = session
1164 .db()
1165 .await
1166 .update_diagnostic_summary(&message, session.connection_id)
1167 .await?;
1168
1169 broadcast(
1170 session.connection_id,
1171 guest_connection_ids.iter().copied(),
1172 |connection_id| {
1173 session
1174 .peer
1175 .forward_send(session.connection_id, connection_id, message.clone())
1176 },
1177 );
1178
1179 Ok(())
1180}
1181
1182async fn start_language_server(
1183 request: proto::StartLanguageServer,
1184 session: Session,
1185) -> Result<()> {
1186 let guest_connection_ids = session
1187 .db()
1188 .await
1189 .start_language_server(&request, session.connection_id)
1190 .await?;
1191
1192 broadcast(
1193 session.connection_id,
1194 guest_connection_ids.iter().copied(),
1195 |connection_id| {
1196 session
1197 .peer
1198 .forward_send(session.connection_id, connection_id, request.clone())
1199 },
1200 );
1201 Ok(())
1202}
1203
1204async fn update_language_server(
1205 request: proto::UpdateLanguageServer,
1206 session: Session,
1207) -> Result<()> {
1208 let project_id = ProjectId::from_proto(request.project_id);
1209 let project_connection_ids = session
1210 .db()
1211 .await
1212 .project_connection_ids(project_id, session.connection_id)
1213 .await?;
1214 broadcast(
1215 session.connection_id,
1216 project_connection_ids.iter().copied(),
1217 |connection_id| {
1218 session
1219 .peer
1220 .forward_send(session.connection_id, connection_id, request.clone())
1221 },
1222 );
1223 Ok(())
1224}
1225
1226async fn forward_project_request<T>(
1227 request: T,
1228 response: Response<T>,
1229 session: Session,
1230) -> Result<()>
1231where
1232 T: EntityMessage + RequestMessage,
1233{
1234 let project_id = ProjectId::from_proto(request.remote_entity_id());
1235 let host_connection_id = {
1236 let collaborators = session
1237 .db()
1238 .await
1239 .project_collaborators(project_id, session.connection_id)
1240 .await?;
1241 ConnectionId(
1242 collaborators
1243 .iter()
1244 .find(|collaborator| collaborator.is_host)
1245 .ok_or_else(|| anyhow!("host not found"))?
1246 .connection_id as u32,
1247 )
1248 };
1249
1250 let payload = session
1251 .peer
1252 .forward_request(session.connection_id, host_connection_id, request)
1253 .await?;
1254
1255 response.send(payload)?;
1256 Ok(())
1257}
1258
1259async fn save_buffer(
1260 request: proto::SaveBuffer,
1261 response: Response<proto::SaveBuffer>,
1262 session: Session,
1263) -> Result<()> {
1264 let project_id = ProjectId::from_proto(request.project_id);
1265 let host_connection_id = {
1266 let collaborators = session
1267 .db()
1268 .await
1269 .project_collaborators(project_id, session.connection_id)
1270 .await?;
1271 let host = collaborators
1272 .iter()
1273 .find(|collaborator| collaborator.is_host)
1274 .ok_or_else(|| anyhow!("host not found"))?;
1275 ConnectionId(host.connection_id as u32)
1276 };
1277 let response_payload = session
1278 .peer
1279 .forward_request(session.connection_id, host_connection_id, request.clone())
1280 .await?;
1281
1282 let mut collaborators = session
1283 .db()
1284 .await
1285 .project_collaborators(project_id, session.connection_id)
1286 .await?;
1287 collaborators
1288 .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32);
1289 let project_connection_ids = collaborators
1290 .iter()
1291 .map(|collaborator| ConnectionId(collaborator.connection_id as u32));
1292 broadcast(host_connection_id, project_connection_ids, |conn_id| {
1293 session
1294 .peer
1295 .forward_send(host_connection_id, conn_id, response_payload.clone())
1296 });
1297 response.send(response_payload)?;
1298 Ok(())
1299}
1300
1301async fn create_buffer_for_peer(
1302 request: proto::CreateBufferForPeer,
1303 session: Session,
1304) -> Result<()> {
1305 session.peer.forward_send(
1306 session.connection_id,
1307 ConnectionId(request.peer_id),
1308 request,
1309 )?;
1310 Ok(())
1311}
1312
1313async fn update_buffer(
1314 request: proto::UpdateBuffer,
1315 response: Response<proto::UpdateBuffer>,
1316 session: Session,
1317) -> Result<()> {
1318 let project_id = ProjectId::from_proto(request.project_id);
1319 let project_connection_ids = session
1320 .db()
1321 .await
1322 .project_connection_ids(project_id, session.connection_id)
1323 .await?;
1324
1325 broadcast(
1326 session.connection_id,
1327 project_connection_ids.iter().copied(),
1328 |connection_id| {
1329 session
1330 .peer
1331 .forward_send(session.connection_id, connection_id, request.clone())
1332 },
1333 );
1334 response.send(proto::Ack {})?;
1335 Ok(())
1336}
1337
1338async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1339 let project_id = ProjectId::from_proto(request.project_id);
1340 let project_connection_ids = session
1341 .db()
1342 .await
1343 .project_connection_ids(project_id, session.connection_id)
1344 .await?;
1345
1346 broadcast(
1347 session.connection_id,
1348 project_connection_ids.iter().copied(),
1349 |connection_id| {
1350 session
1351 .peer
1352 .forward_send(session.connection_id, connection_id, request.clone())
1353 },
1354 );
1355 Ok(())
1356}
1357
1358async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1359 let project_id = ProjectId::from_proto(request.project_id);
1360 let project_connection_ids = session
1361 .db()
1362 .await
1363 .project_connection_ids(project_id, session.connection_id)
1364 .await?;
1365 broadcast(
1366 session.connection_id,
1367 project_connection_ids.iter().copied(),
1368 |connection_id| {
1369 session
1370 .peer
1371 .forward_send(session.connection_id, connection_id, request.clone())
1372 },
1373 );
1374 Ok(())
1375}
1376
1377async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1378 let project_id = ProjectId::from_proto(request.project_id);
1379 let project_connection_ids = session
1380 .db()
1381 .await
1382 .project_connection_ids(project_id, session.connection_id)
1383 .await?;
1384 broadcast(
1385 session.connection_id,
1386 project_connection_ids.iter().copied(),
1387 |connection_id| {
1388 session
1389 .peer
1390 .forward_send(session.connection_id, connection_id, request.clone())
1391 },
1392 );
1393 Ok(())
1394}
1395
1396async fn follow(
1397 request: proto::Follow,
1398 response: Response<proto::Follow>,
1399 session: Session,
1400) -> Result<()> {
1401 let project_id = ProjectId::from_proto(request.project_id);
1402 let leader_id = ConnectionId(request.leader_id);
1403 let follower_id = session.connection_id;
1404 {
1405 let project_connection_ids = session
1406 .db()
1407 .await
1408 .project_connection_ids(project_id, session.connection_id)
1409 .await?;
1410
1411 if !project_connection_ids.contains(&leader_id) {
1412 Err(anyhow!("no such peer"))?;
1413 }
1414 }
1415
1416 let mut response_payload = session
1417 .peer
1418 .forward_request(session.connection_id, leader_id, request)
1419 .await?;
1420 response_payload
1421 .views
1422 .retain(|view| view.leader_id != Some(follower_id.0));
1423 response.send(response_payload)?;
1424 Ok(())
1425}
1426
1427async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1428 let project_id = ProjectId::from_proto(request.project_id);
1429 let leader_id = ConnectionId(request.leader_id);
1430 let project_connection_ids = session
1431 .db()
1432 .await
1433 .project_connection_ids(project_id, session.connection_id)
1434 .await?;
1435 if !project_connection_ids.contains(&leader_id) {
1436 Err(anyhow!("no such peer"))?;
1437 }
1438 session
1439 .peer
1440 .forward_send(session.connection_id, leader_id, request)?;
1441 Ok(())
1442}
1443
1444async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1445 let project_id = ProjectId::from_proto(request.project_id);
1446 let project_connection_ids = session
1447 .db
1448 .lock()
1449 .await
1450 .project_connection_ids(project_id, session.connection_id)
1451 .await?;
1452
1453 let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1454 proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1455 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1456 proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1457 });
1458 for follower_id in &request.follower_ids {
1459 let follower_id = ConnectionId(*follower_id);
1460 if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1461 session
1462 .peer
1463 .forward_send(session.connection_id, follower_id, request.clone())?;
1464 }
1465 }
1466 Ok(())
1467}
1468
1469async fn get_users(
1470 request: proto::GetUsers,
1471 response: Response<proto::GetUsers>,
1472 session: Session,
1473) -> Result<()> {
1474 let user_ids = request
1475 .user_ids
1476 .into_iter()
1477 .map(UserId::from_proto)
1478 .collect();
1479 let users = session
1480 .db()
1481 .await
1482 .get_users_by_ids(user_ids)
1483 .await?
1484 .into_iter()
1485 .map(|user| proto::User {
1486 id: user.id.to_proto(),
1487 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1488 github_login: user.github_login,
1489 })
1490 .collect();
1491 response.send(proto::UsersResponse { users })?;
1492 Ok(())
1493}
1494
1495async fn fuzzy_search_users(
1496 request: proto::FuzzySearchUsers,
1497 response: Response<proto::FuzzySearchUsers>,
1498 session: Session,
1499) -> Result<()> {
1500 let query = request.query;
1501 let users = match query.len() {
1502 0 => vec![],
1503 1 | 2 => session
1504 .db()
1505 .await
1506 .get_user_by_github_account(&query, None)
1507 .await?
1508 .into_iter()
1509 .collect(),
1510 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1511 };
1512 let users = users
1513 .into_iter()
1514 .filter(|user| user.id != session.user_id)
1515 .map(|user| proto::User {
1516 id: user.id.to_proto(),
1517 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1518 github_login: user.github_login,
1519 })
1520 .collect();
1521 response.send(proto::UsersResponse { users })?;
1522 Ok(())
1523}
1524
1525async fn request_contact(
1526 request: proto::RequestContact,
1527 response: Response<proto::RequestContact>,
1528 session: Session,
1529) -> Result<()> {
1530 let requester_id = session.user_id;
1531 let responder_id = UserId::from_proto(request.responder_id);
1532 if requester_id == responder_id {
1533 return Err(anyhow!("cannot add yourself as a contact"))?;
1534 }
1535
1536 session
1537 .db()
1538 .await
1539 .send_contact_request(requester_id, responder_id)
1540 .await?;
1541
1542 // Update outgoing contact requests of requester
1543 let mut update = proto::UpdateContacts::default();
1544 update.outgoing_requests.push(responder_id.to_proto());
1545 for connection_id in session
1546 .connection_pool()
1547 .await
1548 .user_connection_ids(requester_id)
1549 {
1550 session.peer.send(connection_id, update.clone())?;
1551 }
1552
1553 // Update incoming contact requests of responder
1554 let mut update = proto::UpdateContacts::default();
1555 update
1556 .incoming_requests
1557 .push(proto::IncomingContactRequest {
1558 requester_id: requester_id.to_proto(),
1559 should_notify: true,
1560 });
1561 for connection_id in session
1562 .connection_pool()
1563 .await
1564 .user_connection_ids(responder_id)
1565 {
1566 session.peer.send(connection_id, update.clone())?;
1567 }
1568
1569 response.send(proto::Ack {})?;
1570 Ok(())
1571}
1572
1573async fn respond_to_contact_request(
1574 request: proto::RespondToContactRequest,
1575 response: Response<proto::RespondToContactRequest>,
1576 session: Session,
1577) -> Result<()> {
1578 let responder_id = session.user_id;
1579 let requester_id = UserId::from_proto(request.requester_id);
1580 let db = session.db().await;
1581 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1582 db.dismiss_contact_notification(responder_id, requester_id)
1583 .await?;
1584 } else {
1585 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1586
1587 db.respond_to_contact_request(responder_id, requester_id, accept)
1588 .await?;
1589 let requester_busy = db.is_user_busy(requester_id).await?;
1590 let responder_busy = db.is_user_busy(responder_id).await?;
1591
1592 let pool = session.connection_pool().await;
1593 // Update responder with new contact
1594 let mut update = proto::UpdateContacts::default();
1595 if accept {
1596 update
1597 .contacts
1598 .push(contact_for_user(requester_id, false, requester_busy, &pool));
1599 }
1600 update
1601 .remove_incoming_requests
1602 .push(requester_id.to_proto());
1603 for connection_id in pool.user_connection_ids(responder_id) {
1604 session.peer.send(connection_id, update.clone())?;
1605 }
1606
1607 // Update requester with new contact
1608 let mut update = proto::UpdateContacts::default();
1609 if accept {
1610 update
1611 .contacts
1612 .push(contact_for_user(responder_id, true, responder_busy, &pool));
1613 }
1614 update
1615 .remove_outgoing_requests
1616 .push(responder_id.to_proto());
1617 for connection_id in pool.user_connection_ids(requester_id) {
1618 session.peer.send(connection_id, update.clone())?;
1619 }
1620 }
1621
1622 response.send(proto::Ack {})?;
1623 Ok(())
1624}
1625
1626async fn remove_contact(
1627 request: proto::RemoveContact,
1628 response: Response<proto::RemoveContact>,
1629 session: Session,
1630) -> Result<()> {
1631 let requester_id = session.user_id;
1632 let responder_id = UserId::from_proto(request.user_id);
1633 let db = session.db().await;
1634 db.remove_contact(requester_id, responder_id).await?;
1635
1636 let pool = session.connection_pool().await;
1637 // Update outgoing contact requests of requester
1638 let mut update = proto::UpdateContacts::default();
1639 update
1640 .remove_outgoing_requests
1641 .push(responder_id.to_proto());
1642 for connection_id in pool.user_connection_ids(requester_id) {
1643 session.peer.send(connection_id, update.clone())?;
1644 }
1645
1646 // Update incoming contact requests of responder
1647 let mut update = proto::UpdateContacts::default();
1648 update
1649 .remove_incoming_requests
1650 .push(requester_id.to_proto());
1651 for connection_id in pool.user_connection_ids(responder_id) {
1652 session.peer.send(connection_id, update.clone())?;
1653 }
1654
1655 response.send(proto::Ack {})?;
1656 Ok(())
1657}
1658
1659async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
1660 let project_id = ProjectId::from_proto(request.project_id);
1661 let project_connection_ids = session
1662 .db()
1663 .await
1664 .project_connection_ids(project_id, session.connection_id)
1665 .await?;
1666 broadcast(
1667 session.connection_id,
1668 project_connection_ids.iter().copied(),
1669 |connection_id| {
1670 session
1671 .peer
1672 .forward_send(session.connection_id, connection_id, request.clone())
1673 },
1674 );
1675 Ok(())
1676}
1677
1678async fn get_private_user_info(
1679 _request: proto::GetPrivateUserInfo,
1680 response: Response<proto::GetPrivateUserInfo>,
1681 session: Session,
1682) -> Result<()> {
1683 let metrics_id = session
1684 .db()
1685 .await
1686 .get_user_metrics_id(session.user_id)
1687 .await?;
1688 let user = session
1689 .db()
1690 .await
1691 .get_user_by_id(session.user_id)
1692 .await?
1693 .ok_or_else(|| anyhow!("user not found"))?;
1694 response.send(proto::GetPrivateUserInfoResponse {
1695 metrics_id,
1696 staff: user.admin,
1697 })?;
1698 Ok(())
1699}
1700
1701fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1702 match message {
1703 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1704 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1705 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1706 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1707 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1708 code: frame.code.into(),
1709 reason: frame.reason,
1710 })),
1711 }
1712}
1713
1714fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1715 match message {
1716 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1717 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1718 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1719 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1720 AxumMessage::Close(frame) => {
1721 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1722 code: frame.code.into(),
1723 reason: frame.reason,
1724 }))
1725 }
1726 }
1727}
1728
1729fn build_initial_contacts_update(
1730 contacts: Vec<db::Contact>,
1731 pool: &ConnectionPool,
1732) -> proto::UpdateContacts {
1733 let mut update = proto::UpdateContacts::default();
1734
1735 for contact in contacts {
1736 match contact {
1737 db::Contact::Accepted {
1738 user_id,
1739 should_notify,
1740 busy,
1741 } => {
1742 update
1743 .contacts
1744 .push(contact_for_user(user_id, should_notify, busy, &pool));
1745 }
1746 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
1747 db::Contact::Incoming {
1748 user_id,
1749 should_notify,
1750 } => update
1751 .incoming_requests
1752 .push(proto::IncomingContactRequest {
1753 requester_id: user_id.to_proto(),
1754 should_notify,
1755 }),
1756 }
1757 }
1758
1759 update
1760}
1761
1762fn contact_for_user(
1763 user_id: UserId,
1764 should_notify: bool,
1765 busy: bool,
1766 pool: &ConnectionPool,
1767) -> proto::Contact {
1768 proto::Contact {
1769 user_id: user_id.to_proto(),
1770 online: pool.is_user_online(user_id),
1771 busy,
1772 should_notify,
1773 }
1774}
1775
1776fn room_updated(room: &proto::Room, session: &Session) {
1777 for participant in &room.participants {
1778 session
1779 .peer
1780 .send(
1781 ConnectionId(participant.peer_id),
1782 proto::RoomUpdated {
1783 room: Some(room.clone()),
1784 },
1785 )
1786 .trace_err();
1787 }
1788}
1789
1790async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
1791 let db = session.db().await;
1792 let contacts = db.get_contacts(user_id).await?;
1793 let busy = db.is_user_busy(user_id).await?;
1794
1795 let pool = session.connection_pool().await;
1796 let updated_contact = contact_for_user(user_id, false, busy, &pool);
1797 for contact in contacts {
1798 if let db::Contact::Accepted {
1799 user_id: contact_user_id,
1800 ..
1801 } = contact
1802 {
1803 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
1804 session
1805 .peer
1806 .send(
1807 contact_conn_id,
1808 proto::UpdateContacts {
1809 contacts: vec![updated_contact.clone()],
1810 remove_contacts: Default::default(),
1811 incoming_requests: Default::default(),
1812 remove_incoming_requests: Default::default(),
1813 outgoing_requests: Default::default(),
1814 remove_outgoing_requests: Default::default(),
1815 },
1816 )
1817 .trace_err();
1818 }
1819 }
1820 }
1821 Ok(())
1822}
1823
/// Removes this session's connection from its room and performs the follow-up
/// work:
/// - notifies the projects it left (via `project_left`);
/// - pushes the new room state to the remaining participants;
/// - sends `CallCanceled` to users whose calls were canceled;
/// - refreshes contact status for the leaver and those users;
/// - removes the LiveKit participant, deleting the LiveKit room once the room
///   has no participants left.
///
/// Errors from the database or from `update_user_contacts` are propagated.
/// Because of the `?` on `update_user_contacts`, a failure there returns
/// before the LiveKit cleanup runs.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    // Users whose contacts should receive a refreshed contact entry.
    let mut contacts_to_update = HashSet::default();

    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    // Scoped so `left_room` is dropped before the connection pool is locked
    // below. NOTE(review): `left_room` may hold a room-level lock; confirm
    // against `Database::leave_room`.
    {
        let mut left_room = session.db().await.leave_room(session.connection_id).await?;
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_updated(&left_room.room, &session);
        // Move these values out so they outlive `left_room`.
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.room.participants.is_empty();
    }

    // Tell every connection of each canceled callee that the call is gone.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(connection_id, proto::CallCanceled {})
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged, not propagated.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.connection_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
1874
1875fn project_left(project: &db::LeftProject, session: &Session) {
1876 for connection_id in &project.connection_ids {
1877 if project.host_user_id == session.user_id {
1878 session
1879 .peer
1880 .send(
1881 *connection_id,
1882 proto::UnshareProject {
1883 project_id: project.id.to_proto(),
1884 },
1885 )
1886 .trace_err();
1887 } else {
1888 session
1889 .peer
1890 .send(
1891 *connection_id,
1892 proto::RemoveProjectCollaborator {
1893 project_id: project.id.to_proto(),
1894 peer_id: session.connection_id.0,
1895 },
1896 )
1897 .trace_err();
1898 }
1899 }
1900
1901 session
1902 .peer
1903 .send(
1904 session.connection_id,
1905 proto::UnshareProject {
1906 project_id: project.id.to_proto(),
1907 },
1908 )
1909 .trace_err();
1910}
1911
1912pub trait ResultExt {
1913 type Ok;
1914
1915 fn trace_err(self) -> Option<Self::Ok>;
1916}
1917
1918impl<T, E> ResultExt for Result<T, E>
1919where
1920 E: std::fmt::Debug,
1921{
1922 type Ok = T;
1923
1924 fn trace_err(self) -> Option<T> {
1925 match self {
1926 Ok(value) => Some(value),
1927 Err(error) => {
1928 tracing::error!("{:?}", error);
1929 None
1930 }
1931 }
1932 }
1933}