rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{self, Database, ProjectId, RoomId, User, UserId},
   6    executor::Executor,
   7    AppState, Result,
   8};
   9use anyhow::anyhow;
  10use async_tungstenite::tungstenite::{
  11    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  12};
  13use axum::{
  14    body::Body,
  15    extract::{
  16        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  17        ConnectInfo, WebSocketUpgrade,
  18    },
  19    headers::{Header, HeaderName},
  20    http::StatusCode,
  21    middleware,
  22    response::IntoResponse,
  23    routing::get,
  24    Extension, Router, TypedHeader,
  25};
  26use collections::{HashMap, HashSet};
  27pub use connection_pool::ConnectionPool;
  28use futures::{
  29    channel::oneshot,
  30    future::{self, BoxFuture},
  31    stream::FuturesUnordered,
  32    FutureExt, SinkExt, StreamExt, TryStreamExt,
  33};
  34use lazy_static::lazy_static;
  35use prometheus::{register_int_gauge, IntGauge};
  36use rpc::{
  37    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  38    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  39};
  40use serde::{Serialize, Serializer};
  41use std::{
  42    any::TypeId,
  43    fmt,
  44    future::Future,
  45    marker::PhantomData,
  46    mem,
  47    net::SocketAddr,
  48    ops::{Deref, DerefMut},
  49    rc::Rc,
  50    sync::{
  51        atomic::{AtomicBool, Ordering::SeqCst},
  52        Arc,
  53    },
  54    time::Duration,
  55};
  56use tokio::sync::watch;
  57use tower::ServiceBuilder;
  58use tracing::{info_span, instrument, Instrument};
  59
/// Grace period used after a connection is lost (in `sign_out`) and after server startup
/// (in `Server::start`) before room membership is cleaned up. Reuses the RPC layer's
/// `rpc::RECEIVE_TIMEOUT`.
pub const RECONNECT_TIMEOUT: Duration = rpc::RECEIVE_TIMEOUT;

lazy_static! {
    // Number of non-admin connections; refreshed on every scrape by `handle_metrics`.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Project count reported by `project_count_excluding_admins`; refreshed by
    // `handle_metrics`. Registration panics if a gauge with this name already exists.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}

/// Type-erased handler stored in `Server::handlers`, keyed by message `TypeId`.
///
/// Takes the boxed envelope plus the session of the connection it arrived on and returns a
/// boxed future that performs the handling (errors are logged inside the future, see
/// `Server::add_handler`).
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  74
  75struct Response<R> {
  76    peer: Arc<Peer>,
  77    receipt: Receipt<R>,
  78    responded: Arc<AtomicBool>,
  79}
  80
  81impl<R: RequestMessage> Response<R> {
  82    fn send(self, payload: R::Response) -> Result<()> {
  83        self.responded.store(true, SeqCst);
  84        self.peer.respond(self.receipt, payload)?;
  85        Ok(())
  86    }
  87}
  88
/// Per-connection state passed to every message handler.
///
/// Built once per connection in `Server::handle_connection` and cloned for each incoming
/// message, so all clones share the same database mutex, peer and connection pool.
#[derive(Clone)]
struct Session {
    user_id: UserId,
    connection_id: ConnectionId,
    // One mutex per connection (created in `handle_connection`): handlers running for the
    // same connection take turns using the database handle through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    // Server-wide pool shared by every session; locked through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
}

impl Session {
    /// Locks this connection's database handle, waiting asynchronously while another
    /// handler on the same connection holds it.
    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
        // Test builds yield before and after acquiring the lock — presumably to provoke
        // interleavings between concurrently running handlers.
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.db.lock().await;
        #[cfg(test)]
        tokio::task::yield_now().await;
        guard
    }

    /// Locks the server-wide connection pool.
    ///
    /// The underlying lock is a blocking `parking_lot` mutex; the `async` signature appears
    /// to exist only so test builds can yield first. The returned guard is `!Send` (it
    /// carries `PhantomData<Rc<()>>`), so it must be dropped before the next `.await` in any
    /// future that has to be `Send`.
    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.connection_pool.lock();
        ConnectionPoolGuard {
            guard,
            _not_send: PhantomData,
        }
    }
}

impl fmt::Debug for Session {
    // Prints only the identifying fields; the shared handles are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Session")
            .field("user_id", &self.user_id)
            .field("connection_id", &self.connection_id)
            .finish()
    }
}
 128
 129struct DbHandle(Arc<Database>);
 130
 131impl Deref for DbHandle {
 132    type Target = Database;
 133
 134    fn deref(&self) -> &Self::Target {
 135        self.0.as_ref()
 136    }
 137}
 138
/// The collaboration RPC server.
///
/// Owns the `Peer` that tracks client connections, the pool of connected users, and the
/// per-message-type handlers registered in `Server::new`.
pub struct Server {
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    // Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    // Signalled by `Server::teardown`; connection loops and `sign_out` subscribe to it.
    teardown: watch::Sender<()>,
}

/// Lock guard for the connection pool.
///
/// `PhantomData<Rc<()>>` makes the guard `!Send`, so a `Send` future cannot hold it across
/// an `.await`. In test builds, dropping the guard checks the pool's invariants.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable view of server state, produced by `Server::snapshot`.
///
/// Holds the connection-pool lock for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 159
 160pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 161where
 162    S: Serializer,
 163    T: Deref<Target = U>,
 164    U: Serialize,
 165{
 166    Serialize::serialize(value.deref(), serializer)
 167}
 168
 169impl Server {
 170    pub fn new(app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
 171        let mut server = Self {
 172            peer: Peer::new(),
 173            app_state,
 174            executor,
 175            connection_pool: Default::default(),
 176            handlers: Default::default(),
 177            teardown: watch::channel(()).0,
 178        };
 179
 180        server
 181            .add_request_handler(ping)
 182            .add_request_handler(create_room)
 183            .add_request_handler(join_room)
 184            .add_message_handler(leave_room)
 185            .add_request_handler(call)
 186            .add_request_handler(cancel_call)
 187            .add_message_handler(decline_call)
 188            .add_request_handler(update_participant_location)
 189            .add_request_handler(share_project)
 190            .add_message_handler(unshare_project)
 191            .add_request_handler(join_project)
 192            .add_message_handler(leave_project)
 193            .add_request_handler(update_project)
 194            .add_request_handler(update_worktree)
 195            .add_message_handler(start_language_server)
 196            .add_message_handler(update_language_server)
 197            .add_message_handler(update_diagnostic_summary)
 198            .add_request_handler(forward_project_request::<proto::GetHover>)
 199            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 200            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 201            .add_request_handler(forward_project_request::<proto::GetReferences>)
 202            .add_request_handler(forward_project_request::<proto::SearchProject>)
 203            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 204            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 205            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 206            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 207            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 208            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 209            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 210            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 211            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 212            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 213            .add_request_handler(forward_project_request::<proto::PerformRename>)
 214            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 215            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 216            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 217            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 218            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 219            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 220            .add_message_handler(create_buffer_for_peer)
 221            .add_request_handler(update_buffer)
 222            .add_message_handler(update_buffer_file)
 223            .add_message_handler(buffer_reloaded)
 224            .add_message_handler(buffer_saved)
 225            .add_request_handler(save_buffer)
 226            .add_request_handler(get_users)
 227            .add_request_handler(fuzzy_search_users)
 228            .add_request_handler(request_contact)
 229            .add_request_handler(remove_contact)
 230            .add_request_handler(respond_to_contact_request)
 231            .add_request_handler(follow)
 232            .add_message_handler(unfollow)
 233            .add_message_handler(update_followers)
 234            .add_message_handler(update_diff_base)
 235            .add_request_handler(get_private_user_info);
 236
 237        Arc::new(server)
 238    }
 239
 240    pub async fn start(&self) -> Result<()> {
 241        self.app_state.db.delete_stale_projects().await?;
 242        let db = self.app_state.db.clone();
 243        let peer = self.peer.clone();
 244        let timeout = self.executor.sleep(RECONNECT_TIMEOUT);
 245        let pool = self.connection_pool.clone();
 246        let live_kit_client = self.app_state.live_kit_client.clone();
 247        self.executor.spawn_detached(async move {
 248            timeout.await;
 249            if let Some(room_ids) = db.outdated_room_ids().await.trace_err() {
 250                for room_id in room_ids {
 251                    let mut contacts_to_update = HashSet::default();
 252                    let mut canceled_calls_to_user_ids = Vec::new();
 253                    let mut live_kit_room = String::new();
 254                    let mut delete_live_kit_room = false;
 255
 256                    if let Ok(mut refreshed_room) = db.refresh_room(room_id).await {
 257                        room_updated(&refreshed_room.room, &peer);
 258                        contacts_to_update
 259                            .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 260                        contacts_to_update
 261                            .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 262                        canceled_calls_to_user_ids =
 263                            mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 264                        live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 265                        delete_live_kit_room = refreshed_room.room.participants.is_empty();
 266                    }
 267
 268                    {
 269                        let pool = pool.lock();
 270                        for canceled_user_id in canceled_calls_to_user_ids {
 271                            for connection_id in pool.user_connection_ids(canceled_user_id) {
 272                                peer.send(
 273                                    connection_id,
 274                                    proto::CallCanceled {
 275                                        room_id: room_id.to_proto(),
 276                                    },
 277                                )
 278                                .trace_err();
 279                            }
 280                        }
 281                    }
 282
 283                    for user_id in contacts_to_update {
 284                        let busy = db.is_user_busy(user_id).await.trace_err();
 285                        let contacts = db.get_contacts(user_id).await.trace_err();
 286                        if let Some((busy, contacts)) = busy.zip(contacts) {
 287                            let pool = pool.lock();
 288                            let updated_contact = contact_for_user(user_id, false, busy, &pool);
 289                            for contact in contacts {
 290                                if let db::Contact::Accepted {
 291                                    user_id: contact_user_id,
 292                                    ..
 293                                } = contact
 294                                {
 295                                    for contact_conn_id in pool.user_connection_ids(contact_user_id)
 296                                    {
 297                                        peer.send(
 298                                            contact_conn_id,
 299                                            proto::UpdateContacts {
 300                                                contacts: vec![updated_contact.clone()],
 301                                                remove_contacts: Default::default(),
 302                                                incoming_requests: Default::default(),
 303                                                remove_incoming_requests: Default::default(),
 304                                                outgoing_requests: Default::default(),
 305                                                remove_outgoing_requests: Default::default(),
 306                                            },
 307                                        )
 308                                        .trace_err();
 309                                    }
 310                                }
 311                            }
 312                        }
 313                    }
 314
 315                    if let Some(live_kit) = live_kit_client.as_ref() {
 316                        if delete_live_kit_room {
 317                            live_kit.delete_room(live_kit_room).await.trace_err();
 318                        }
 319                    }
 320                }
 321            }
 322        });
 323        Ok(())
 324    }
 325
 326    pub fn teardown(&self) {
 327        self.peer.reset();
 328        let _ = self.teardown.send(());
 329    }
 330
 331    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 332    where
 333        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 334        Fut: 'static + Send + Future<Output = Result<()>>,
 335        M: EnvelopedMessage,
 336    {
 337        let prev_handler = self.handlers.insert(
 338            TypeId::of::<M>(),
 339            Box::new(move |envelope, session| {
 340                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 341                let span = info_span!(
 342                    "handle message",
 343                    payload_type = envelope.payload_type_name()
 344                );
 345                span.in_scope(|| {
 346                    tracing::info!(
 347                        payload_type = envelope.payload_type_name(),
 348                        "message received"
 349                    );
 350                });
 351                let future = (handler)(*envelope, session);
 352                async move {
 353                    if let Err(error) = future.await {
 354                        tracing::error!(%error, "error handling message");
 355                    }
 356                }
 357                .instrument(span)
 358                .boxed()
 359            }),
 360        );
 361        if prev_handler.is_some() {
 362            panic!("registered a handler for the same message twice");
 363        }
 364        self
 365    }
 366
 367    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 368    where
 369        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 370        Fut: 'static + Send + Future<Output = Result<()>>,
 371        M: EnvelopedMessage,
 372    {
 373        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 374        self
 375    }
 376
 377    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 378    where
 379        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 380        Fut: Send + Future<Output = Result<()>>,
 381        M: RequestMessage,
 382    {
 383        let handler = Arc::new(handler);
 384        self.add_handler(move |envelope, session| {
 385            let receipt = envelope.receipt();
 386            let handler = handler.clone();
 387            async move {
 388                let peer = session.peer.clone();
 389                let responded = Arc::new(AtomicBool::default());
 390                let response = Response {
 391                    peer: peer.clone(),
 392                    responded: responded.clone(),
 393                    receipt,
 394                };
 395                match (handler)(envelope.payload, response, session).await {
 396                    Ok(()) => {
 397                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 398                            Ok(())
 399                        } else {
 400                            Err(anyhow!("handler did not send a response"))?
 401                        }
 402                    }
 403                    Err(error) => {
 404                        peer.respond_with_error(
 405                            receipt,
 406                            proto::Error {
 407                                message: error.to_string(),
 408                            },
 409                        )?;
 410                        Err(error)
 411                    }
 412                }
 413            }
 414        })
 415    }
 416
 417    pub fn handle_connection(
 418        self: &Arc<Self>,
 419        connection: Connection,
 420        address: String,
 421        user: User,
 422        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 423        executor: Executor,
 424    ) -> impl Future<Output = Result<()>> {
 425        let this = self.clone();
 426        let user_id = user.id;
 427        let login = user.github_login;
 428        let span = info_span!("handle connection", %user_id, %login, %address);
 429        let mut teardown = self.teardown.subscribe();
 430        async move {
 431            let (connection_id, handle_io, mut incoming_rx) = this
 432                .peer
 433                .add_connection(connection, {
 434                    let executor = executor.clone();
 435                    move |duration| executor.sleep(duration)
 436                });
 437
 438            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 439            this.peer.send(connection_id, proto::Hello { peer_id: connection_id.0 })?;
 440            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
 441
 442            if let Some(send_connection_id) = send_connection_id.take() {
 443                let _ = send_connection_id.send(connection_id);
 444            }
 445
 446            if !user.connected_once {
 447                this.peer.send(connection_id, proto::ShowContacts {})?;
 448                this.app_state.db.set_user_connected_once(user_id, true).await?;
 449            }
 450
 451            let (contacts, invite_code) = future::try_join(
 452                this.app_state.db.get_contacts(user_id),
 453                this.app_state.db.get_invite_code_for_user(user_id)
 454            ).await?;
 455
 456            {
 457                let mut pool = this.connection_pool.lock();
 458                pool.add_connection(connection_id, user_id, user.admin);
 459                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
 460
 461                if let Some((code, count)) = invite_code {
 462                    this.peer.send(connection_id, proto::UpdateInviteInfo {
 463                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
 464                        count: count as u32,
 465                    })?;
 466                }
 467            }
 468
 469            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
 470                this.peer.send(connection_id, incoming_call)?;
 471            }
 472
 473            let session = Session {
 474                user_id,
 475                connection_id,
 476                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 477                peer: this.peer.clone(),
 478                connection_pool: this.connection_pool.clone(),
 479                live_kit_client: this.app_state.live_kit_client.clone()
 480            };
 481            update_user_contacts(user_id, &session).await?;
 482
 483            let handle_io = handle_io.fuse();
 484            futures::pin_mut!(handle_io);
 485
 486            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 487            // This prevents deadlocks when e.g., client A performs a request to client B and
 488            // client B performs a request to client A. If both clients stop processing further
 489            // messages until their respective request completes, they won't have a chance to
 490            // respond to the other client's request and cause a deadlock.
 491            //
 492            // This arrangement ensures we will attempt to process earlier messages first, but fall
 493            // back to processing messages arrived later in the spirit of making progress.
 494            let mut foreground_message_handlers = FuturesUnordered::new();
 495            loop {
 496                let next_message = incoming_rx.next().fuse();
 497                futures::pin_mut!(next_message);
 498                futures::select_biased! {
 499                    _ = teardown.changed().fuse() => return Ok(()),
 500                    result = handle_io => {
 501                        if let Err(error) = result {
 502                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 503                        }
 504                        break;
 505                    }
 506                    _ = foreground_message_handlers.next() => {}
 507                    message = next_message => {
 508                        if let Some(message) = message {
 509                            let type_name = message.payload_type_name();
 510                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 511                            let span_enter = span.enter();
 512                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 513                                let is_background = message.is_background();
 514                                let handle_message = (handler)(message, session.clone());
 515                                drop(span_enter);
 516
 517                                let handle_message = handle_message.instrument(span);
 518                                if is_background {
 519                                    executor.spawn_detached(handle_message);
 520                                } else {
 521                                    foreground_message_handlers.push(handle_message);
 522                                }
 523                            } else {
 524                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 525                            }
 526                        } else {
 527                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 528                            break;
 529                        }
 530                    }
 531                }
 532            }
 533
 534            drop(foreground_message_handlers);
 535            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 536            if let Err(error) = sign_out(session, teardown, executor).await {
 537                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 538            }
 539
 540            Ok(())
 541        }.instrument(span)
 542    }
 543
 544    pub async fn invite_code_redeemed(
 545        self: &Arc<Self>,
 546        inviter_id: UserId,
 547        invitee_id: UserId,
 548    ) -> Result<()> {
 549        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 550            if let Some(code) = &user.invite_code {
 551                let pool = self.connection_pool.lock();
 552                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 553                for connection_id in pool.user_connection_ids(inviter_id) {
 554                    self.peer.send(
 555                        connection_id,
 556                        proto::UpdateContacts {
 557                            contacts: vec![invitee_contact.clone()],
 558                            ..Default::default()
 559                        },
 560                    )?;
 561                    self.peer.send(
 562                        connection_id,
 563                        proto::UpdateInviteInfo {
 564                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 565                            count: user.invite_count as u32,
 566                        },
 567                    )?;
 568                }
 569            }
 570        }
 571        Ok(())
 572    }
 573
 574    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 575        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 576            if let Some(invite_code) = &user.invite_code {
 577                let pool = self.connection_pool.lock();
 578                for connection_id in pool.user_connection_ids(user_id) {
 579                    self.peer.send(
 580                        connection_id,
 581                        proto::UpdateInviteInfo {
 582                            url: format!(
 583                                "{}{}",
 584                                self.app_state.config.invite_link_prefix, invite_code
 585                            ),
 586                            count: user.invite_count as u32,
 587                        },
 588                    )?;
 589                }
 590            }
 591        }
 592        Ok(())
 593    }
 594
 595    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 596        ServerSnapshot {
 597            connection_pool: ConnectionPoolGuard {
 598                guard: self.connection_pool.lock(),
 599                _not_send: PhantomData,
 600            },
 601            peer: &self.peer,
 602        }
 603    }
 604}
 605
 606impl<'a> Deref for ConnectionPoolGuard<'a> {
 607    type Target = ConnectionPool;
 608
 609    fn deref(&self) -> &Self::Target {
 610        &*self.guard
 611    }
 612}
 613
 614impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 615    fn deref_mut(&mut self) -> &mut Self::Target {
 616        &mut *self.guard
 617    }
 618}
 619
 620impl<'a> Drop for ConnectionPoolGuard<'a> {
 621    fn drop(&mut self) {
 622        #[cfg(test)]
 623        self.check_invariants();
 624    }
 625}
 626
 627fn broadcast<F>(
 628    sender_id: ConnectionId,
 629    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 630    mut f: F,
 631) where
 632    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 633{
 634    for receiver_id in receiver_ids {
 635        if receiver_id != sender_id {
 636            f(receiver_id).trace_err();
 637        }
 638    }
 639}
 640
 641lazy_static! {
 642    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
 643}
 644
 645pub struct ProtocolVersion(u32);
 646
 647impl Header for ProtocolVersion {
 648    fn name() -> &'static HeaderName {
 649        &ZED_PROTOCOL_VERSION
 650    }
 651
 652    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 653    where
 654        Self: Sized,
 655        I: Iterator<Item = &'i axum::http::HeaderValue>,
 656    {
 657        let version = values
 658            .next()
 659            .ok_or_else(axum::headers::Error::invalid)?
 660            .to_str()
 661            .map_err(|_| axum::headers::Error::invalid())?
 662            .parse()
 663            .map_err(|_| axum::headers::Error::invalid())?;
 664        Ok(Self(version))
 665    }
 666
 667    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 668        values.extend([self.0.to_string().parse().unwrap()]);
 669    }
 670}
 671
/// Builds the axum router for the RPC server.
///
/// Route/layer order is significant: `Router::layer` only wraps routes added before it.
/// - `/rpc` is therefore behind `auth::validate_header` and the `AppState` extension.
/// - `/metrics` is added afterwards and is not covered by that auth layer.
/// - The final `Extension(server)` layer wraps both routes.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
 683
/// Axum handler for `GET /rpc`: upgrades the request to a WebSocket and hands the
/// resulting connection to `Server::handle_connection`.
///
/// Clients whose protocol-version header differs from `rpc::PROTOCOL_VERSION` get
/// `426 Upgrade Required` instead of an upgrade. The `Extension<User>` is presumably
/// inserted by the auth middleware applied in `routes` — TODO confirm against
/// `auth::validate_header`. Errors from the connection are logged, not returned.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(user): Extension<User>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }
    let socket_address = socket_address.to_string();
    ws.on_upgrade(move |socket| {
        use util::ResultExt;
        // Convert between axum's WebSocket message type and the tungstenite message type
        // on both the stream and sink sides, presumably because `rpc::Connection` is built
        // around tungstenite messages.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { Ok(to_axum_message(message)) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(connection, socket_address, user, None, Executor::Production)
                .await
                .log_err();
        }
    })
}
 714
 715pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 716    let connections = server
 717        .connection_pool
 718        .lock()
 719        .connections()
 720        .filter(|connection| !connection.admin)
 721        .count();
 722
 723    METRIC_CONNECTIONS.set(connections as _);
 724
 725    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 726    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 727
 728    let encoder = prometheus::TextEncoder::new();
 729    let metric_families = prometheus::gather();
 730    let encoded_metrics = encoder
 731        .encode_to_string(&metric_families)
 732        .map_err(|err| anyhow!("{}", err))?;
 733    Ok(encoded_metrics)
 734}
 735
/// Cleans up after a connection's message loop has ended.
///
/// Immediately:
/// - disconnects the peer;
/// - removes the connection from the pool, returning early with the error if that fails;
/// - runs `connection_lost`, calling `project_left` for every project it reports.
///
/// Then waits `RECONNECT_TIMEOUT` unless server teardown fires first, in which case it
/// returns without further cleanup. After the wait it:
/// - leaves the room via `leave_room_for_session`;
/// - if the user has no other connection online, calls `decline_call(None, user_id)` and
///   broadcasts the resulting room update;
/// - calls `update_user_contacts`.
///
/// Only `remove_connection` and `update_user_contacts` errors are returned; the other
/// failures are logged via `trace_err`.
#[instrument(err, skip(executor))]
async fn sign_out(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // The database guard from `session.db()` is a temporary of the `if let` scrutinee, so it
    // stays locked while the (synchronous) `project_left` calls run.
    if let Some(mut left_projects) = session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err()
    {
        for left_project in mem::take(&mut *left_projects) {
            project_left(&left_project, &session);
        }
    }

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            leave_room_for_session(&session).await.trace_err();

            // The pool guard is a temporary of the `if` condition and is released before
            // the body runs, so it is never held across the awaits below.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err() {
                    room_updated(&room, &session.peer);
                }
            }
            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
 781
 782async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 783    response.send(proto::Ack {})?;
 784    Ok(())
 785}
 786
 787async fn create_room(
 788    _request: proto::CreateRoom,
 789    response: Response<proto::CreateRoom>,
 790    session: Session,
 791) -> Result<()> {
 792    let live_kit_room = nanoid::nanoid!(30);
 793    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 794        if let Some(_) = live_kit
 795            .create_room(live_kit_room.clone())
 796            .await
 797            .trace_err()
 798        {
 799            if let Some(token) = live_kit
 800                .room_token(&live_kit_room, &session.connection_id.to_string())
 801                .trace_err()
 802            {
 803                Some(proto::LiveKitConnectionInfo {
 804                    server_url: live_kit.url().into(),
 805                    token,
 806                })
 807            } else {
 808                None
 809            }
 810        } else {
 811            None
 812        }
 813    } else {
 814        None
 815    };
 816
 817    {
 818        let room = session
 819            .db()
 820            .await
 821            .create_room(session.user_id, session.connection_id, &live_kit_room)
 822            .await?;
 823
 824        response.send(proto::CreateRoomResponse {
 825            room: Some(room.clone()),
 826            live_kit_connection_info,
 827        })?;
 828    }
 829
 830    update_user_contacts(session.user_id, &session).await?;
 831    Ok(())
 832}
 833
 834async fn join_room(
 835    request: proto::JoinRoom,
 836    response: Response<proto::JoinRoom>,
 837    session: Session,
 838) -> Result<()> {
 839    let room_id = RoomId::from_proto(request.id);
 840    let room = {
 841        let room = session
 842            .db()
 843            .await
 844            .join_room(room_id, session.user_id, session.connection_id)
 845            .await?;
 846        room_updated(&room, &session.peer);
 847        room.clone()
 848    };
 849
 850    for connection_id in session
 851        .connection_pool()
 852        .await
 853        .user_connection_ids(session.user_id)
 854    {
 855        session
 856            .peer
 857            .send(
 858                connection_id,
 859                proto::CallCanceled {
 860                    room_id: room_id.to_proto(),
 861                },
 862            )
 863            .trace_err();
 864    }
 865
 866    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 867        if let Some(token) = live_kit
 868            .room_token(&room.live_kit_room, &session.connection_id.to_string())
 869            .trace_err()
 870        {
 871            Some(proto::LiveKitConnectionInfo {
 872                server_url: live_kit.url().into(),
 873                token,
 874            })
 875        } else {
 876            None
 877        }
 878    } else {
 879        None
 880    };
 881
 882    response.send(proto::JoinRoomResponse {
 883        room: Some(room),
 884        live_kit_connection_info,
 885    })?;
 886
 887    update_user_contacts(session.user_id, &session).await?;
 888    Ok(())
 889}
 890
 891async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
 892    leave_room_for_session(&session).await
 893}
 894
/// Handles a `Call` request: rings `called_user_id` on behalf of the sender.
///
/// Fails if the two users are not contacts. Otherwise:
/// - The call is recorded in the database, the room update is broadcast, and
///   the callee's contacts are refreshed.
/// - The incoming-call message is sent as a request to every connection of the
///   callee concurrently.
/// - The first successful answer acks the caller and the handler returns. The
///   remaining in-flight requests are dropped unawaited, and failed rings are
///   logged.
///
/// If no connection answers successfully (including when the callee has no
/// connections at all):
/// - The call is marked as failed in the database.
/// - The room update is broadcast again and the callee's contacts refreshed.
/// - The error "failed to ring user" is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id;
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        // Move the message out so it outlives this block's borrow of the
        // database result.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the callee at once.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // Nobody answered successfully: roll the call back.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
 962
/// Handles a `CancelCall` request from the caller.
///
/// The steps, in order:
/// - Cancel the pending call to `called_user_id` in `room_id` in the database
///   and broadcast the room update.
/// - Send `CallCanceled` to every connection of the callee. Send failures are
///   only logged.
/// - Ack the request.
/// - Refresh the callee's contacts.
async fn cancel_call(
    request: proto::CancelCall,
    response: Response<proto::CancelCall>,
    session: Session,
) -> Result<()> {
    let called_user_id = UserId::from_proto(request.called_user_id);
    let room_id = RoomId::from_proto(request.room_id);
    // Scoped so the database result is dropped before the pool is locked below.
    {
        let room = session
            .db()
            .await
            .cancel_call(Some(room_id), session.connection_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    response.send(proto::Ack {})?;

    update_user_contacts(called_user_id, &session).await?;
    Ok(())
}
 999
/// Handles the callee declining an incoming call for `room_id`.
///
/// The steps, in order:
/// - Record the decline in the database and broadcast the room update.
/// - Send `CallCanceled` to every connection of the declining user.
///   Presumably this stops the call ringing on the user's other devices.
///   Send failures are only logged.
/// - Refresh the user's contacts.
async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
    let room_id = RoomId::from_proto(message.room_id);
    // Scoped so the database result is dropped before the pool is locked below.
    {
        let room = session
            .db()
            .await
            .decline_call(Some(room_id), session.user_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id)
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
1029
1030async fn update_participant_location(
1031    request: proto::UpdateParticipantLocation,
1032    response: Response<proto::UpdateParticipantLocation>,
1033    session: Session,
1034) -> Result<()> {
1035    let room_id = RoomId::from_proto(request.room_id);
1036    let location = request
1037        .location
1038        .ok_or_else(|| anyhow!("invalid location"))?;
1039    let room = session
1040        .db()
1041        .await
1042        .update_room_participant_location(room_id, session.connection_id, location)
1043        .await?;
1044    room_updated(&room, &session.peer);
1045    response.send(proto::Ack {})?;
1046    Ok(())
1047}
1048
1049async fn share_project(
1050    request: proto::ShareProject,
1051    response: Response<proto::ShareProject>,
1052    session: Session,
1053) -> Result<()> {
1054    let (project_id, room) = &*session
1055        .db()
1056        .await
1057        .share_project(
1058            RoomId::from_proto(request.room_id),
1059            session.connection_id,
1060            &request.worktrees,
1061        )
1062        .await?;
1063    response.send(proto::ShareProjectResponse {
1064        project_id: project_id.to_proto(),
1065    })?;
1066    room_updated(&room, &session.peer);
1067
1068    Ok(())
1069}
1070
1071async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1072    let project_id = ProjectId::from_proto(message.project_id);
1073
1074    let (room, guest_connection_ids) = &*session
1075        .db()
1076        .await
1077        .unshare_project(project_id, session.connection_id)
1078        .await?;
1079
1080    broadcast(
1081        session.connection_id,
1082        guest_connection_ids.iter().copied(),
1083        |conn_id| session.peer.send(conn_id, message.clone()),
1084    );
1085    room_updated(&room, &session.peer);
1086
1087    Ok(())
1088}
1089
1090async fn join_project(
1091    request: proto::JoinProject,
1092    response: Response<proto::JoinProject>,
1093    session: Session,
1094) -> Result<()> {
1095    let project_id = ProjectId::from_proto(request.project_id);
1096    let guest_user_id = session.user_id;
1097
1098    tracing::info!(%project_id, "join project");
1099
1100    let (project, replica_id) = &mut *session
1101        .db()
1102        .await
1103        .join_project(project_id, session.connection_id)
1104        .await?;
1105
1106    let collaborators = project
1107        .collaborators
1108        .iter()
1109        .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32)
1110        .map(|collaborator| proto::Collaborator {
1111            peer_id: collaborator.connection_id as u32,
1112            replica_id: collaborator.replica_id.0 as u32,
1113            user_id: collaborator.user_id.to_proto(),
1114        })
1115        .collect::<Vec<_>>();
1116    let worktrees = project
1117        .worktrees
1118        .iter()
1119        .map(|(id, worktree)| proto::WorktreeMetadata {
1120            id: *id,
1121            root_name: worktree.root_name.clone(),
1122            visible: worktree.visible,
1123            abs_path: worktree.abs_path.clone(),
1124        })
1125        .collect::<Vec<_>>();
1126
1127    for collaborator in &collaborators {
1128        session
1129            .peer
1130            .send(
1131                ConnectionId(collaborator.peer_id),
1132                proto::AddProjectCollaborator {
1133                    project_id: project_id.to_proto(),
1134                    collaborator: Some(proto::Collaborator {
1135                        peer_id: session.connection_id.0,
1136                        replica_id: replica_id.0 as u32,
1137                        user_id: guest_user_id.to_proto(),
1138                    }),
1139                },
1140            )
1141            .trace_err();
1142    }
1143
1144    // First, we send the metadata associated with each worktree.
1145    response.send(proto::JoinProjectResponse {
1146        worktrees: worktrees.clone(),
1147        replica_id: replica_id.0 as u32,
1148        collaborators: collaborators.clone(),
1149        language_servers: project.language_servers.clone(),
1150    })?;
1151
1152    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1153        #[cfg(any(test, feature = "test-support"))]
1154        const MAX_CHUNK_SIZE: usize = 2;
1155        #[cfg(not(any(test, feature = "test-support")))]
1156        const MAX_CHUNK_SIZE: usize = 256;
1157
1158        // Stream this worktree's entries.
1159        let message = proto::UpdateWorktree {
1160            project_id: project_id.to_proto(),
1161            worktree_id,
1162            abs_path: worktree.abs_path.clone(),
1163            root_name: worktree.root_name,
1164            updated_entries: worktree.entries,
1165            removed_entries: Default::default(),
1166            scan_id: worktree.scan_id,
1167            is_last_update: worktree.is_complete,
1168        };
1169        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1170            session.peer.send(session.connection_id, update.clone())?;
1171        }
1172
1173        // Stream this worktree's diagnostics.
1174        for summary in worktree.diagnostic_summaries {
1175            session.peer.send(
1176                session.connection_id,
1177                proto::UpdateDiagnosticSummary {
1178                    project_id: project_id.to_proto(),
1179                    worktree_id: worktree.id,
1180                    summary: Some(summary),
1181                },
1182            )?;
1183        }
1184    }
1185
1186    for language_server in &project.language_servers {
1187        session.peer.send(
1188            session.connection_id,
1189            proto::UpdateLanguageServer {
1190                project_id: project_id.to_proto(),
1191                language_server_id: language_server.id,
1192                variant: Some(
1193                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1194                        proto::LspDiskBasedDiagnosticsUpdated {},
1195                    ),
1196                ),
1197            },
1198        )?;
1199    }
1200
1201    Ok(())
1202}
1203
1204async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1205    let sender_id = session.connection_id;
1206    let project_id = ProjectId::from_proto(request.project_id);
1207
1208    let project = session
1209        .db()
1210        .await
1211        .leave_project(project_id, sender_id)
1212        .await?;
1213    tracing::info!(
1214        %project_id,
1215        host_user_id = %project.host_user_id,
1216        host_connection_id = %project.host_connection_id,
1217        "leave project"
1218    );
1219    project_left(&project, &session);
1220
1221    Ok(())
1222}
1223
1224async fn update_project(
1225    request: proto::UpdateProject,
1226    response: Response<proto::UpdateProject>,
1227    session: Session,
1228) -> Result<()> {
1229    let project_id = ProjectId::from_proto(request.project_id);
1230    let (room, guest_connection_ids) = &*session
1231        .db()
1232        .await
1233        .update_project(project_id, session.connection_id, &request.worktrees)
1234        .await?;
1235    broadcast(
1236        session.connection_id,
1237        guest_connection_ids.iter().copied(),
1238        |connection_id| {
1239            session
1240                .peer
1241                .forward_send(session.connection_id, connection_id, request.clone())
1242        },
1243    );
1244    room_updated(&room, &session.peer);
1245    response.send(proto::Ack {})?;
1246
1247    Ok(())
1248}
1249
1250async fn update_worktree(
1251    request: proto::UpdateWorktree,
1252    response: Response<proto::UpdateWorktree>,
1253    session: Session,
1254) -> Result<()> {
1255    let guest_connection_ids = session
1256        .db()
1257        .await
1258        .update_worktree(&request, session.connection_id)
1259        .await?;
1260
1261    broadcast(
1262        session.connection_id,
1263        guest_connection_ids.iter().copied(),
1264        |connection_id| {
1265            session
1266                .peer
1267                .forward_send(session.connection_id, connection_id, request.clone())
1268        },
1269    );
1270    response.send(proto::Ack {})?;
1271    Ok(())
1272}
1273
1274async fn update_diagnostic_summary(
1275    message: proto::UpdateDiagnosticSummary,
1276    session: Session,
1277) -> Result<()> {
1278    let guest_connection_ids = session
1279        .db()
1280        .await
1281        .update_diagnostic_summary(&message, session.connection_id)
1282        .await?;
1283
1284    broadcast(
1285        session.connection_id,
1286        guest_connection_ids.iter().copied(),
1287        |connection_id| {
1288            session
1289                .peer
1290                .forward_send(session.connection_id, connection_id, message.clone())
1291        },
1292    );
1293
1294    Ok(())
1295}
1296
1297async fn start_language_server(
1298    request: proto::StartLanguageServer,
1299    session: Session,
1300) -> Result<()> {
1301    let guest_connection_ids = session
1302        .db()
1303        .await
1304        .start_language_server(&request, session.connection_id)
1305        .await?;
1306
1307    broadcast(
1308        session.connection_id,
1309        guest_connection_ids.iter().copied(),
1310        |connection_id| {
1311            session
1312                .peer
1313                .forward_send(session.connection_id, connection_id, request.clone())
1314        },
1315    );
1316    Ok(())
1317}
1318
1319async fn update_language_server(
1320    request: proto::UpdateLanguageServer,
1321    session: Session,
1322) -> Result<()> {
1323    let project_id = ProjectId::from_proto(request.project_id);
1324    let project_connection_ids = session
1325        .db()
1326        .await
1327        .project_connection_ids(project_id, session.connection_id)
1328        .await?;
1329    broadcast(
1330        session.connection_id,
1331        project_connection_ids.iter().copied(),
1332        |connection_id| {
1333            session
1334                .peer
1335                .forward_send(session.connection_id, connection_id, request.clone())
1336        },
1337    );
1338    Ok(())
1339}
1340
1341async fn forward_project_request<T>(
1342    request: T,
1343    response: Response<T>,
1344    session: Session,
1345) -> Result<()>
1346where
1347    T: EntityMessage + RequestMessage,
1348{
1349    let project_id = ProjectId::from_proto(request.remote_entity_id());
1350    let host_connection_id = {
1351        let collaborators = session
1352            .db()
1353            .await
1354            .project_collaborators(project_id, session.connection_id)
1355            .await?;
1356        ConnectionId(
1357            collaborators
1358                .iter()
1359                .find(|collaborator| collaborator.is_host)
1360                .ok_or_else(|| anyhow!("host not found"))?
1361                .connection_id as u32,
1362        )
1363    };
1364
1365    let payload = session
1366        .peer
1367        .forward_request(session.connection_id, host_connection_id, request)
1368        .await?;
1369
1370    response.send(payload)?;
1371    Ok(())
1372}
1373
1374async fn save_buffer(
1375    request: proto::SaveBuffer,
1376    response: Response<proto::SaveBuffer>,
1377    session: Session,
1378) -> Result<()> {
1379    let project_id = ProjectId::from_proto(request.project_id);
1380    let host_connection_id = {
1381        let collaborators = session
1382            .db()
1383            .await
1384            .project_collaborators(project_id, session.connection_id)
1385            .await?;
1386        let host = collaborators
1387            .iter()
1388            .find(|collaborator| collaborator.is_host)
1389            .ok_or_else(|| anyhow!("host not found"))?;
1390        ConnectionId(host.connection_id as u32)
1391    };
1392    let response_payload = session
1393        .peer
1394        .forward_request(session.connection_id, host_connection_id, request.clone())
1395        .await?;
1396
1397    let mut collaborators = session
1398        .db()
1399        .await
1400        .project_collaborators(project_id, session.connection_id)
1401        .await?;
1402    collaborators
1403        .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32);
1404    let project_connection_ids = collaborators
1405        .iter()
1406        .map(|collaborator| ConnectionId(collaborator.connection_id as u32));
1407    broadcast(host_connection_id, project_connection_ids, |conn_id| {
1408        session
1409            .peer
1410            .forward_send(host_connection_id, conn_id, response_payload.clone())
1411    });
1412    response.send(response_payload)?;
1413    Ok(())
1414}
1415
1416async fn create_buffer_for_peer(
1417    request: proto::CreateBufferForPeer,
1418    session: Session,
1419) -> Result<()> {
1420    session.peer.forward_send(
1421        session.connection_id,
1422        ConnectionId(request.peer_id),
1423        request,
1424    )?;
1425    Ok(())
1426}
1427
1428async fn update_buffer(
1429    request: proto::UpdateBuffer,
1430    response: Response<proto::UpdateBuffer>,
1431    session: Session,
1432) -> Result<()> {
1433    let project_id = ProjectId::from_proto(request.project_id);
1434    let project_connection_ids = session
1435        .db()
1436        .await
1437        .project_connection_ids(project_id, session.connection_id)
1438        .await?;
1439
1440    broadcast(
1441        session.connection_id,
1442        project_connection_ids.iter().copied(),
1443        |connection_id| {
1444            session
1445                .peer
1446                .forward_send(session.connection_id, connection_id, request.clone())
1447        },
1448    );
1449    response.send(proto::Ack {})?;
1450    Ok(())
1451}
1452
1453async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1454    let project_id = ProjectId::from_proto(request.project_id);
1455    let project_connection_ids = session
1456        .db()
1457        .await
1458        .project_connection_ids(project_id, session.connection_id)
1459        .await?;
1460
1461    broadcast(
1462        session.connection_id,
1463        project_connection_ids.iter().copied(),
1464        |connection_id| {
1465            session
1466                .peer
1467                .forward_send(session.connection_id, connection_id, request.clone())
1468        },
1469    );
1470    Ok(())
1471}
1472
1473async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1474    let project_id = ProjectId::from_proto(request.project_id);
1475    let project_connection_ids = session
1476        .db()
1477        .await
1478        .project_connection_ids(project_id, session.connection_id)
1479        .await?;
1480    broadcast(
1481        session.connection_id,
1482        project_connection_ids.iter().copied(),
1483        |connection_id| {
1484            session
1485                .peer
1486                .forward_send(session.connection_id, connection_id, request.clone())
1487        },
1488    );
1489    Ok(())
1490}
1491
1492async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1493    let project_id = ProjectId::from_proto(request.project_id);
1494    let project_connection_ids = session
1495        .db()
1496        .await
1497        .project_connection_ids(project_id, session.connection_id)
1498        .await?;
1499    broadcast(
1500        session.connection_id,
1501        project_connection_ids.iter().copied(),
1502        |connection_id| {
1503            session
1504                .peer
1505                .forward_send(session.connection_id, connection_id, request.clone())
1506        },
1507    );
1508    Ok(())
1509}
1510
1511async fn follow(
1512    request: proto::Follow,
1513    response: Response<proto::Follow>,
1514    session: Session,
1515) -> Result<()> {
1516    let project_id = ProjectId::from_proto(request.project_id);
1517    let leader_id = ConnectionId(request.leader_id);
1518    let follower_id = session.connection_id;
1519    {
1520        let project_connection_ids = session
1521            .db()
1522            .await
1523            .project_connection_ids(project_id, session.connection_id)
1524            .await?;
1525
1526        if !project_connection_ids.contains(&leader_id) {
1527            Err(anyhow!("no such peer"))?;
1528        }
1529    }
1530
1531    let mut response_payload = session
1532        .peer
1533        .forward_request(session.connection_id, leader_id, request)
1534        .await?;
1535    response_payload
1536        .views
1537        .retain(|view| view.leader_id != Some(follower_id.0));
1538    response.send(response_payload)?;
1539    Ok(())
1540}
1541
1542async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1543    let project_id = ProjectId::from_proto(request.project_id);
1544    let leader_id = ConnectionId(request.leader_id);
1545    let project_connection_ids = session
1546        .db()
1547        .await
1548        .project_connection_ids(project_id, session.connection_id)
1549        .await?;
1550    if !project_connection_ids.contains(&leader_id) {
1551        Err(anyhow!("no such peer"))?;
1552    }
1553    session
1554        .peer
1555        .forward_send(session.connection_id, leader_id, request)?;
1556    Ok(())
1557}
1558
1559async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1560    let project_id = ProjectId::from_proto(request.project_id);
1561    let project_connection_ids = session
1562        .db
1563        .lock()
1564        .await
1565        .project_connection_ids(project_id, session.connection_id)
1566        .await?;
1567
1568    let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1569        proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1570        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1571        proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1572    });
1573    for follower_id in &request.follower_ids {
1574        let follower_id = ConnectionId(*follower_id);
1575        if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1576            session
1577                .peer
1578                .forward_send(session.connection_id, follower_id, request.clone())?;
1579        }
1580    }
1581    Ok(())
1582}
1583
1584async fn get_users(
1585    request: proto::GetUsers,
1586    response: Response<proto::GetUsers>,
1587    session: Session,
1588) -> Result<()> {
1589    let user_ids = request
1590        .user_ids
1591        .into_iter()
1592        .map(UserId::from_proto)
1593        .collect();
1594    let users = session
1595        .db()
1596        .await
1597        .get_users_by_ids(user_ids)
1598        .await?
1599        .into_iter()
1600        .map(|user| proto::User {
1601            id: user.id.to_proto(),
1602            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1603            github_login: user.github_login,
1604        })
1605        .collect();
1606    response.send(proto::UsersResponse { users })?;
1607    Ok(())
1608}
1609
1610async fn fuzzy_search_users(
1611    request: proto::FuzzySearchUsers,
1612    response: Response<proto::FuzzySearchUsers>,
1613    session: Session,
1614) -> Result<()> {
1615    let query = request.query;
1616    let users = match query.len() {
1617        0 => vec![],
1618        1 | 2 => session
1619            .db()
1620            .await
1621            .get_user_by_github_account(&query, None)
1622            .await?
1623            .into_iter()
1624            .collect(),
1625        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1626    };
1627    let users = users
1628        .into_iter()
1629        .filter(|user| user.id != session.user_id)
1630        .map(|user| proto::User {
1631            id: user.id.to_proto(),
1632            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1633            github_login: user.github_login,
1634        })
1635        .collect();
1636    response.send(proto::UsersResponse { users })?;
1637    Ok(())
1638}
1639
1640async fn request_contact(
1641    request: proto::RequestContact,
1642    response: Response<proto::RequestContact>,
1643    session: Session,
1644) -> Result<()> {
1645    let requester_id = session.user_id;
1646    let responder_id = UserId::from_proto(request.responder_id);
1647    if requester_id == responder_id {
1648        return Err(anyhow!("cannot add yourself as a contact"))?;
1649    }
1650
1651    session
1652        .db()
1653        .await
1654        .send_contact_request(requester_id, responder_id)
1655        .await?;
1656
1657    // Update outgoing contact requests of requester
1658    let mut update = proto::UpdateContacts::default();
1659    update.outgoing_requests.push(responder_id.to_proto());
1660    for connection_id in session
1661        .connection_pool()
1662        .await
1663        .user_connection_ids(requester_id)
1664    {
1665        session.peer.send(connection_id, update.clone())?;
1666    }
1667
1668    // Update incoming contact requests of responder
1669    let mut update = proto::UpdateContacts::default();
1670    update
1671        .incoming_requests
1672        .push(proto::IncomingContactRequest {
1673            requester_id: requester_id.to_proto(),
1674            should_notify: true,
1675        });
1676    for connection_id in session
1677        .connection_pool()
1678        .await
1679        .user_connection_ids(responder_id)
1680    {
1681        session.peer.send(connection_id, update.clone())?;
1682    }
1683
1684    response.send(proto::Ack {})?;
1685    Ok(())
1686}
1687
/// Handles the responder's answer to a pending contact request.
///
/// A `Dismiss` response only calls `dismiss_contact_notification`. Any other
/// response is treated as accept (`Accept`) or decline (anything else), and
/// leads to these steps:
/// - Store the response in the database.
/// - Look up both users' busy status.
/// - Send the responder's connections an update that removes the incoming
///   request and, if accepted, adds the requester as a contact.
/// - Send the requester's connections an update that removes the outgoing
///   request and, if accepted, adds the responder as a contact.
///
/// The sender is acked in all cases.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: Session,
) -> Result<()> {
    let responder_id = session.user_id;
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        db.respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        // Held for both notification loops below; no awaits occur while it is held.
        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, false, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, true, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
1740
1741async fn remove_contact(
1742    request: proto::RemoveContact,
1743    response: Response<proto::RemoveContact>,
1744    session: Session,
1745) -> Result<()> {
1746    let requester_id = session.user_id;
1747    let responder_id = UserId::from_proto(request.user_id);
1748    let db = session.db().await;
1749    db.remove_contact(requester_id, responder_id).await?;
1750
1751    let pool = session.connection_pool().await;
1752    // Update outgoing contact requests of requester
1753    let mut update = proto::UpdateContacts::default();
1754    update
1755        .remove_outgoing_requests
1756        .push(responder_id.to_proto());
1757    for connection_id in pool.user_connection_ids(requester_id) {
1758        session.peer.send(connection_id, update.clone())?;
1759    }
1760
1761    // Update incoming contact requests of responder
1762    let mut update = proto::UpdateContacts::default();
1763    update
1764        .remove_incoming_requests
1765        .push(requester_id.to_proto());
1766    for connection_id in pool.user_connection_ids(responder_id) {
1767        session.peer.send(connection_id, update.clone())?;
1768    }
1769
1770    response.send(proto::Ack {})?;
1771    Ok(())
1772}
1773
1774async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
1775    let project_id = ProjectId::from_proto(request.project_id);
1776    let project_connection_ids = session
1777        .db()
1778        .await
1779        .project_connection_ids(project_id, session.connection_id)
1780        .await?;
1781    broadcast(
1782        session.connection_id,
1783        project_connection_ids.iter().copied(),
1784        |connection_id| {
1785            session
1786                .peer
1787                .forward_send(session.connection_id, connection_id, request.clone())
1788        },
1789    );
1790    Ok(())
1791}
1792
1793async fn get_private_user_info(
1794    _request: proto::GetPrivateUserInfo,
1795    response: Response<proto::GetPrivateUserInfo>,
1796    session: Session,
1797) -> Result<()> {
1798    let metrics_id = session
1799        .db()
1800        .await
1801        .get_user_metrics_id(session.user_id)
1802        .await?;
1803    let user = session
1804        .db()
1805        .await
1806        .get_user_by_id(session.user_id)
1807        .await?
1808        .ok_or_else(|| anyhow!("user not found"))?;
1809    response.send(proto::GetPrivateUserInfoResponse {
1810        metrics_id,
1811        staff: user.admin,
1812    })?;
1813    Ok(())
1814}
1815
1816fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1817    match message {
1818        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1819        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1820        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1821        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1822        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1823            code: frame.code.into(),
1824            reason: frame.reason,
1825        })),
1826    }
1827}
1828
1829fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1830    match message {
1831        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1832        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1833        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1834        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1835        AxumMessage::Close(frame) => {
1836            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1837                code: frame.code.into(),
1838                reason: frame.reason,
1839            }))
1840        }
1841    }
1842}
1843
1844fn build_initial_contacts_update(
1845    contacts: Vec<db::Contact>,
1846    pool: &ConnectionPool,
1847) -> proto::UpdateContacts {
1848    let mut update = proto::UpdateContacts::default();
1849
1850    for contact in contacts {
1851        match contact {
1852            db::Contact::Accepted {
1853                user_id,
1854                should_notify,
1855                busy,
1856            } => {
1857                update
1858                    .contacts
1859                    .push(contact_for_user(user_id, should_notify, busy, &pool));
1860            }
1861            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
1862            db::Contact::Incoming {
1863                user_id,
1864                should_notify,
1865            } => update
1866                .incoming_requests
1867                .push(proto::IncomingContactRequest {
1868                    requester_id: user_id.to_proto(),
1869                    should_notify,
1870                }),
1871        }
1872    }
1873
1874    update
1875}
1876
1877fn contact_for_user(
1878    user_id: UserId,
1879    should_notify: bool,
1880    busy: bool,
1881    pool: &ConnectionPool,
1882) -> proto::Contact {
1883    proto::Contact {
1884        user_id: user_id.to_proto(),
1885        online: pool.is_user_online(user_id),
1886        busy,
1887        should_notify,
1888    }
1889}
1890
1891fn room_updated(room: &proto::Room, peer: &Peer) {
1892    for participant in &room.participants {
1893        peer.send(
1894            ConnectionId(participant.peer_id),
1895            proto::RoomUpdated {
1896                room: Some(room.clone()),
1897            },
1898        )
1899        .trace_err();
1900    }
1901}
1902
1903async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
1904    let db = session.db().await;
1905    let contacts = db.get_contacts(user_id).await?;
1906    let busy = db.is_user_busy(user_id).await?;
1907
1908    let pool = session.connection_pool().await;
1909    let updated_contact = contact_for_user(user_id, false, busy, &pool);
1910    for contact in contacts {
1911        if let db::Contact::Accepted {
1912            user_id: contact_user_id,
1913            ..
1914        } = contact
1915        {
1916            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
1917                session
1918                    .peer
1919                    .send(
1920                        contact_conn_id,
1921                        proto::UpdateContacts {
1922                            contacts: vec![updated_contact.clone()],
1923                            remove_contacts: Default::default(),
1924                            incoming_requests: Default::default(),
1925                            remove_incoming_requests: Default::default(),
1926                            outgoing_requests: Default::default(),
1927                            remove_outgoing_requests: Default::default(),
1928                        },
1929                    )
1930                    .trace_err();
1931            }
1932        }
1933    }
1934    Ok(())
1935}
1936
1937async fn leave_room_for_session(session: &Session) -> Result<()> {
1938    let mut contacts_to_update = HashSet::default();
1939
1940    let room_id;
1941    let canceled_calls_to_user_ids;
1942    let live_kit_room;
1943    let delete_live_kit_room;
1944    {
1945        let mut left_room = session.db().await.leave_room(session.connection_id).await?;
1946        contacts_to_update.insert(session.user_id);
1947
1948        for project in left_room.left_projects.values() {
1949            project_left(project, session);
1950        }
1951
1952        room_updated(&left_room.room, &session.peer);
1953        room_id = RoomId::from_proto(left_room.room.id);
1954        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
1955        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
1956        delete_live_kit_room = left_room.room.participants.is_empty();
1957    }
1958
1959    {
1960        let pool = session.connection_pool().await;
1961        for canceled_user_id in canceled_calls_to_user_ids {
1962            for connection_id in pool.user_connection_ids(canceled_user_id) {
1963                session
1964                    .peer
1965                    .send(
1966                        connection_id,
1967                        proto::CallCanceled {
1968                            room_id: room_id.to_proto(),
1969                        },
1970                    )
1971                    .trace_err();
1972            }
1973            contacts_to_update.insert(canceled_user_id);
1974        }
1975    }
1976
1977    for contact_user_id in contacts_to_update {
1978        update_user_contacts(contact_user_id, &session).await?;
1979    }
1980
1981    if let Some(live_kit) = session.live_kit_client.as_ref() {
1982        live_kit
1983            .remove_participant(live_kit_room.clone(), session.connection_id.to_string())
1984            .await
1985            .trace_err();
1986
1987        if delete_live_kit_room {
1988            live_kit.delete_room(live_kit_room).await.trace_err();
1989        }
1990    }
1991
1992    Ok(())
1993}
1994
1995fn project_left(project: &db::LeftProject, session: &Session) {
1996    for connection_id in &project.connection_ids {
1997        if project.host_user_id == session.user_id {
1998            session
1999                .peer
2000                .send(
2001                    *connection_id,
2002                    proto::UnshareProject {
2003                        project_id: project.id.to_proto(),
2004                    },
2005                )
2006                .trace_err();
2007        } else {
2008            session
2009                .peer
2010                .send(
2011                    *connection_id,
2012                    proto::RemoveProjectCollaborator {
2013                        project_id: project.id.to_proto(),
2014                        peer_id: session.connection_id.0,
2015                    },
2016                )
2017                .trace_err();
2018        }
2019    }
2020
2021    session
2022        .peer
2023        .send(
2024            session.connection_id,
2025            proto::UnshareProject {
2026                project_id: project.id.to_proto(),
2027            },
2028        )
2029        .trace_err();
2030}
2031
2032pub trait ResultExt {
2033    type Ok;
2034
2035    fn trace_err(self) -> Option<Self::Ok>;
2036}
2037
2038impl<T, E> ResultExt for Result<T, E>
2039where
2040    E: std::fmt::Debug,
2041{
2042    type Ok = T;
2043
2044    fn trace_err(self) -> Option<T> {
2045        match self {
2046            Ok(value) => Some(value),
2047            Err(error) => {
2048                tracing::error!("{:?}", error);
2049                None
2050            }
2051        }
2052    }
2053}