rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{self, Database, ProjectId, RoomId, User, UserId},
   6    executor::Executor,
   7    AppState, Result,
   8};
   9use anyhow::anyhow;
  10use async_tungstenite::tungstenite::{
  11    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  12};
  13use axum::{
  14    body::Body,
  15    extract::{
  16        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  17        ConnectInfo, WebSocketUpgrade,
  18    },
  19    headers::{Header, HeaderName},
  20    http::StatusCode,
  21    middleware,
  22    response::IntoResponse,
  23    routing::get,
  24    Extension, Router, TypedHeader,
  25};
  26use collections::{HashMap, HashSet};
  27pub use connection_pool::ConnectionPool;
  28use futures::{
  29    channel::oneshot,
  30    future::{self, BoxFuture},
  31    stream::FuturesUnordered,
  32    FutureExt, SinkExt, StreamExt, TryStreamExt,
  33};
  34use lazy_static::lazy_static;
  35use prometheus::{register_int_gauge, IntGauge};
  36use rpc::{
  37    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  38    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  39};
  40use serde::{Serialize, Serializer};
  41use std::{
  42    any::TypeId,
  43    fmt,
  44    future::Future,
  45    marker::PhantomData,
  46    mem,
  47    net::SocketAddr,
  48    ops::{Deref, DerefMut},
  49    rc::Rc,
  50    sync::{
  51        atomic::{AtomicBool, Ordering::SeqCst},
  52        Arc,
  53    },
  54    time::Duration,
  55};
  56use tokio::sync::{watch, Mutex, MutexGuard};
  57use tower::ServiceBuilder;
  58use tracing::{info_span, instrument, Instrument};
  59
  60pub const RECONNECT_TIMEOUT: Duration = rpc::RECEIVE_TIMEOUT;
  61
  62lazy_static! {
  63    static ref METRIC_CONNECTIONS: IntGauge =
  64        register_int_gauge!("connections", "number of connections").unwrap();
  65    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  66        "shared_projects",
  67        "number of open projects with one or more guests"
  68    )
  69    .unwrap();
  70}
  71
  72type MessageHandler =
  73    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  74
  75struct Response<R> {
  76    peer: Arc<Peer>,
  77    receipt: Receipt<R>,
  78    responded: Arc<AtomicBool>,
  79}
  80
  81impl<R: RequestMessage> Response<R> {
  82    fn send(self, payload: R::Response) -> Result<()> {
  83        self.responded.store(true, SeqCst);
  84        self.peer.respond(self.receipt, payload)?;
  85        Ok(())
  86    }
  87}
  88
  89#[derive(Clone)]
  90struct Session {
  91    user_id: UserId,
  92    connection_id: ConnectionId,
  93    db: Arc<Mutex<DbHandle>>,
  94    peer: Arc<Peer>,
  95    connection_pool: Arc<Mutex<ConnectionPool>>,
  96    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
  97}
  98
  99impl Session {
 100    async fn db(&self) -> MutexGuard<DbHandle> {
 101        #[cfg(test)]
 102        tokio::task::yield_now().await;
 103        let guard = self.db.lock().await;
 104        #[cfg(test)]
 105        tokio::task::yield_now().await;
 106        guard
 107    }
 108
 109    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 110        #[cfg(test)]
 111        tokio::task::yield_now().await;
 112        let guard = self.connection_pool.lock().await;
 113        #[cfg(test)]
 114        tokio::task::yield_now().await;
 115        ConnectionPoolGuard {
 116            guard,
 117            _not_send: PhantomData,
 118        }
 119    }
 120}
 121
 122impl fmt::Debug for Session {
 123    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 124        f.debug_struct("Session")
 125            .field("user_id", &self.user_id)
 126            .field("connection_id", &self.connection_id)
 127            .finish()
 128    }
 129}
 130
 131struct DbHandle(Arc<Database>);
 132
 133impl Deref for DbHandle {
 134    type Target = Database;
 135
 136    fn deref(&self) -> &Self::Target {
 137        self.0.as_ref()
 138    }
 139}
 140
 141pub struct Server {
 142    peer: Arc<Peer>,
 143    pub(crate) connection_pool: Arc<Mutex<ConnectionPool>>,
 144    app_state: Arc<AppState>,
 145    handlers: HashMap<TypeId, MessageHandler>,
 146    teardown: watch::Sender<()>,
 147}
 148
 149pub(crate) struct ConnectionPoolGuard<'a> {
 150    guard: MutexGuard<'a, ConnectionPool>,
 151    _not_send: PhantomData<Rc<()>>,
 152}
 153
 154#[derive(Serialize)]
 155pub struct ServerSnapshot<'a> {
 156    peer: &'a Peer,
 157    #[serde(serialize_with = "serialize_deref")]
 158    connection_pool: ConnectionPoolGuard<'a>,
 159}
 160
 161pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 162where
 163    S: Serializer,
 164    T: Deref<Target = U>,
 165    U: Serialize,
 166{
 167    Serialize::serialize(value.deref(), serializer)
 168}
 169
 170impl Server {
 171    pub fn new(app_state: Arc<AppState>) -> Arc<Self> {
 172        let mut server = Self {
 173            peer: Peer::new(),
 174            app_state,
 175            connection_pool: Default::default(),
 176            handlers: Default::default(),
 177            teardown: watch::channel(()).0,
 178        };
 179
 180        server
 181            .add_request_handler(ping)
 182            .add_request_handler(create_room)
 183            .add_request_handler(join_room)
 184            .add_message_handler(leave_room)
 185            .add_request_handler(call)
 186            .add_request_handler(cancel_call)
 187            .add_message_handler(decline_call)
 188            .add_request_handler(update_participant_location)
 189            .add_request_handler(share_project)
 190            .add_message_handler(unshare_project)
 191            .add_request_handler(join_project)
 192            .add_message_handler(leave_project)
 193            .add_request_handler(update_project)
 194            .add_request_handler(update_worktree)
 195            .add_message_handler(start_language_server)
 196            .add_message_handler(update_language_server)
 197            .add_message_handler(update_diagnostic_summary)
 198            .add_request_handler(forward_project_request::<proto::GetHover>)
 199            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 200            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 201            .add_request_handler(forward_project_request::<proto::GetReferences>)
 202            .add_request_handler(forward_project_request::<proto::SearchProject>)
 203            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 204            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 205            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 206            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 207            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 208            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 209            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 210            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 211            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 212            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 213            .add_request_handler(forward_project_request::<proto::PerformRename>)
 214            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 215            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 216            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 217            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 218            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 219            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 220            .add_message_handler(create_buffer_for_peer)
 221            .add_request_handler(update_buffer)
 222            .add_message_handler(update_buffer_file)
 223            .add_message_handler(buffer_reloaded)
 224            .add_message_handler(buffer_saved)
 225            .add_request_handler(save_buffer)
 226            .add_request_handler(get_users)
 227            .add_request_handler(fuzzy_search_users)
 228            .add_request_handler(request_contact)
 229            .add_request_handler(remove_contact)
 230            .add_request_handler(respond_to_contact_request)
 231            .add_request_handler(follow)
 232            .add_message_handler(unfollow)
 233            .add_message_handler(update_followers)
 234            .add_message_handler(update_diff_base)
 235            .add_request_handler(get_private_user_info);
 236
 237        Arc::new(server)
 238    }
 239
 240    pub fn teardown(&self) {
 241        let _ = self.teardown.send(());
 242    }
 243
 244    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 245    where
 246        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 247        Fut: 'static + Send + Future<Output = Result<()>>,
 248        M: EnvelopedMessage,
 249    {
 250        let prev_handler = self.handlers.insert(
 251            TypeId::of::<M>(),
 252            Box::new(move |envelope, session| {
 253                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 254                let span = info_span!(
 255                    "handle message",
 256                    payload_type = envelope.payload_type_name()
 257                );
 258                span.in_scope(|| {
 259                    tracing::info!(
 260                        payload_type = envelope.payload_type_name(),
 261                        "message received"
 262                    );
 263                });
 264                let future = (handler)(*envelope, session);
 265                async move {
 266                    if let Err(error) = future.await {
 267                        tracing::error!(%error, "error handling message");
 268                    }
 269                }
 270                .instrument(span)
 271                .boxed()
 272            }),
 273        );
 274        if prev_handler.is_some() {
 275            panic!("registered a handler for the same message twice");
 276        }
 277        self
 278    }
 279
 280    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 281    where
 282        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 283        Fut: 'static + Send + Future<Output = Result<()>>,
 284        M: EnvelopedMessage,
 285    {
 286        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 287        self
 288    }
 289
 290    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 291    where
 292        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 293        Fut: Send + Future<Output = Result<()>>,
 294        M: RequestMessage,
 295    {
 296        let handler = Arc::new(handler);
 297        self.add_handler(move |envelope, session| {
 298            let receipt = envelope.receipt();
 299            let handler = handler.clone();
 300            async move {
 301                let peer = session.peer.clone();
 302                let responded = Arc::new(AtomicBool::default());
 303                let response = Response {
 304                    peer: peer.clone(),
 305                    responded: responded.clone(),
 306                    receipt,
 307                };
 308                match (handler)(envelope.payload, response, session).await {
 309                    Ok(()) => {
 310                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 311                            Ok(())
 312                        } else {
 313                            Err(anyhow!("handler did not send a response"))?
 314                        }
 315                    }
 316                    Err(error) => {
 317                        peer.respond_with_error(
 318                            receipt,
 319                            proto::Error {
 320                                message: error.to_string(),
 321                            },
 322                        )?;
 323                        Err(error)
 324                    }
 325                }
 326            }
 327        })
 328    }
 329
 330    pub fn handle_connection(
 331        self: &Arc<Self>,
 332        connection: Connection,
 333        address: String,
 334        user: User,
 335        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 336        executor: Executor,
 337    ) -> impl Future<Output = Result<()>> {
 338        let this = self.clone();
 339        let user_id = user.id;
 340        let login = user.github_login;
 341        let span = info_span!("handle connection", %user_id, %login, %address);
 342        let teardown = self.teardown.subscribe();
 343        async move {
 344            let (connection_id, handle_io, mut incoming_rx) = this
 345                .peer
 346                .add_connection(connection, {
 347                    let executor = executor.clone();
 348                    move |duration| executor.sleep(duration)
 349                });
 350
 351            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 352            this.peer.send(connection_id, proto::Hello { peer_id: connection_id.0 })?;
 353            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
 354
 355            if let Some(send_connection_id) = send_connection_id.take() {
 356                let _ = send_connection_id.send(connection_id);
 357            }
 358
 359            if !user.connected_once {
 360                this.peer.send(connection_id, proto::ShowContacts {})?;
 361                this.app_state.db.set_user_connected_once(user_id, true).await?;
 362            }
 363
 364            let (contacts, invite_code) = future::try_join(
 365                this.app_state.db.get_contacts(user_id),
 366                this.app_state.db.get_invite_code_for_user(user_id)
 367            ).await?;
 368
 369            {
 370                let mut pool = this.connection_pool.lock().await;
 371                pool.add_connection(connection_id, user_id, user.admin);
 372                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
 373
 374                if let Some((code, count)) = invite_code {
 375                    this.peer.send(connection_id, proto::UpdateInviteInfo {
 376                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
 377                        count: count as u32,
 378                    })?;
 379                }
 380            }
 381
 382            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
 383                this.peer.send(connection_id, incoming_call)?;
 384            }
 385
 386            let session = Session {
 387                user_id,
 388                connection_id,
 389                db: Arc::new(Mutex::new(DbHandle(this.app_state.db.clone()))),
 390                peer: this.peer.clone(),
 391                connection_pool: this.connection_pool.clone(),
 392                live_kit_client: this.app_state.live_kit_client.clone()
 393            };
 394            update_user_contacts(user_id, &session).await?;
 395
 396            let handle_io = handle_io.fuse();
 397            futures::pin_mut!(handle_io);
 398
 399            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 400            // This prevents deadlocks when e.g., client A performs a request to client B and
 401            // client B performs a request to client A. If both clients stop processing further
 402            // messages until their respective request completes, they won't have a chance to
 403            // respond to the other client's request and cause a deadlock.
 404            //
 405            // This arrangement ensures we will attempt to process earlier messages first, but fall
 406            // back to processing messages arrived later in the spirit of making progress.
 407            let mut foreground_message_handlers = FuturesUnordered::new();
 408            loop {
 409                let next_message = incoming_rx.next().fuse();
 410                futures::pin_mut!(next_message);
 411                futures::select_biased! {
 412                    result = handle_io => {
 413                        if let Err(error) = result {
 414                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 415                        }
 416                        break;
 417                    }
 418                    _ = foreground_message_handlers.next() => {}
 419                    message = next_message => {
 420                        if let Some(message) = message {
 421                            let type_name = message.payload_type_name();
 422                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 423                            let span_enter = span.enter();
 424                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 425                                let is_background = message.is_background();
 426                                let handle_message = (handler)(message, session.clone());
 427                                drop(span_enter);
 428
 429                                let handle_message = handle_message.instrument(span);
 430                                if is_background {
 431                                    executor.spawn_detached(handle_message);
 432                                } else {
 433                                    foreground_message_handlers.push(handle_message);
 434                                }
 435                            } else {
 436                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 437                            }
 438                        } else {
 439                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 440                            break;
 441                        }
 442                    }
 443                }
 444            }
 445
 446            drop(foreground_message_handlers);
 447            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 448            if let Err(error) = sign_out(session, teardown, executor).await {
 449                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 450            }
 451
 452            Ok(())
 453        }.instrument(span)
 454    }
 455
 456    pub async fn invite_code_redeemed(
 457        self: &Arc<Self>,
 458        inviter_id: UserId,
 459        invitee_id: UserId,
 460    ) -> Result<()> {
 461        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 462            if let Some(code) = &user.invite_code {
 463                let pool = self.connection_pool.lock().await;
 464                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 465                for connection_id in pool.user_connection_ids(inviter_id) {
 466                    self.peer.send(
 467                        connection_id,
 468                        proto::UpdateContacts {
 469                            contacts: vec![invitee_contact.clone()],
 470                            ..Default::default()
 471                        },
 472                    )?;
 473                    self.peer.send(
 474                        connection_id,
 475                        proto::UpdateInviteInfo {
 476                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 477                            count: user.invite_count as u32,
 478                        },
 479                    )?;
 480                }
 481            }
 482        }
 483        Ok(())
 484    }
 485
 486    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 487        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 488            if let Some(invite_code) = &user.invite_code {
 489                let pool = self.connection_pool.lock().await;
 490                for connection_id in pool.user_connection_ids(user_id) {
 491                    self.peer.send(
 492                        connection_id,
 493                        proto::UpdateInviteInfo {
 494                            url: format!(
 495                                "{}{}",
 496                                self.app_state.config.invite_link_prefix, invite_code
 497                            ),
 498                            count: user.invite_count as u32,
 499                        },
 500                    )?;
 501                }
 502            }
 503        }
 504        Ok(())
 505    }
 506
 507    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 508        ServerSnapshot {
 509            connection_pool: ConnectionPoolGuard {
 510                guard: self.connection_pool.lock().await,
 511                _not_send: PhantomData,
 512            },
 513            peer: &self.peer,
 514        }
 515    }
 516}
 517
 518impl<'a> Deref for ConnectionPoolGuard<'a> {
 519    type Target = ConnectionPool;
 520
 521    fn deref(&self) -> &Self::Target {
 522        &*self.guard
 523    }
 524}
 525
 526impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 527    fn deref_mut(&mut self) -> &mut Self::Target {
 528        &mut *self.guard
 529    }
 530}
 531
 532impl<'a> Drop for ConnectionPoolGuard<'a> {
 533    fn drop(&mut self) {
 534        #[cfg(test)]
 535        self.check_invariants();
 536    }
 537}
 538
 539fn broadcast<F>(
 540    sender_id: ConnectionId,
 541    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 542    mut f: F,
 543) where
 544    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 545{
 546    for receiver_id in receiver_ids {
 547        if receiver_id != sender_id {
 548            f(receiver_id).trace_err();
 549        }
 550    }
 551}
 552
 553lazy_static! {
 554    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
 555}
 556
 557pub struct ProtocolVersion(u32);
 558
 559impl Header for ProtocolVersion {
 560    fn name() -> &'static HeaderName {
 561        &ZED_PROTOCOL_VERSION
 562    }
 563
 564    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 565    where
 566        Self: Sized,
 567        I: Iterator<Item = &'i axum::http::HeaderValue>,
 568    {
 569        let version = values
 570            .next()
 571            .ok_or_else(axum::headers::Error::invalid)?
 572            .to_str()
 573            .map_err(|_| axum::headers::Error::invalid())?
 574            .parse()
 575            .map_err(|_| axum::headers::Error::invalid())?;
 576        Ok(Self(version))
 577    }
 578
 579    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 580        values.extend([self.0.to_string().parse().unwrap()]);
 581    }
 582}
 583
 584pub fn routes(server: Arc<Server>) -> Router<Body> {
 585    Router::new()
 586        .route("/rpc", get(handle_websocket_request))
 587        .layer(
 588            ServiceBuilder::new()
 589                .layer(Extension(server.app_state.clone()))
 590                .layer(middleware::from_fn(auth::validate_header)),
 591        )
 592        .route("/metrics", get(handle_metrics))
 593        .layer(Extension(server))
 594}
 595
 596pub async fn handle_websocket_request(
 597    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 598    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 599    Extension(server): Extension<Arc<Server>>,
 600    Extension(user): Extension<User>,
 601    ws: WebSocketUpgrade,
 602) -> axum::response::Response {
 603    if protocol_version != rpc::PROTOCOL_VERSION {
 604        return (
 605            StatusCode::UPGRADE_REQUIRED,
 606            "client must be upgraded".to_string(),
 607        )
 608            .into_response();
 609    }
 610    let socket_address = socket_address.to_string();
 611    ws.on_upgrade(move |socket| {
 612        use util::ResultExt;
 613        let socket = socket
 614            .map_ok(to_tungstenite_message)
 615            .err_into()
 616            .with(|message| async move { Ok(to_axum_message(message)) });
 617        let connection = Connection::new(Box::pin(socket));
 618        async move {
 619            server
 620                .handle_connection(connection, socket_address, user, None, Executor::Production)
 621                .await
 622                .log_err();
 623        }
 624    })
 625}
 626
 627pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 628    let connections = server
 629        .connection_pool
 630        .lock()
 631        .await
 632        .connections()
 633        .filter(|connection| !connection.admin)
 634        .count();
 635
 636    METRIC_CONNECTIONS.set(connections as _);
 637
 638    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 639    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 640
 641    let encoder = prometheus::TextEncoder::new();
 642    let metric_families = prometheus::gather();
 643    let encoded_metrics = encoder
 644        .encode_to_string(&metric_families)
 645        .map_err(|err| anyhow!("{}", err))?;
 646    Ok(encoded_metrics)
 647}
 648
 649#[instrument(err, skip(executor))]
 650async fn sign_out(
 651    session: Session,
 652    mut teardown: watch::Receiver<()>,
 653    executor: Executor,
 654) -> Result<()> {
 655    session.peer.disconnect(session.connection_id);
 656    session
 657        .connection_pool()
 658        .await
 659        .remove_connection(session.connection_id)?;
 660
 661    if let Some(mut left_projects) = session
 662        .db()
 663        .await
 664        .connection_lost(session.connection_id)
 665        .await
 666        .trace_err()
 667    {
 668        for left_project in mem::take(&mut *left_projects) {
 669            project_left(&left_project, &session);
 670        }
 671    }
 672
 673    futures::select_biased! {
 674        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
 675            leave_room_for_session(&session).await.trace_err();
 676
 677            if !session
 678                .connection_pool()
 679                .await
 680                .is_user_online(session.user_id)
 681            {
 682                let db = session.db().await;
 683                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err() {
 684                    room_updated(&room, &session);
 685                }
 686            }
 687            update_user_contacts(session.user_id, &session).await?;
 688        }
 689        _ = teardown.changed().fuse() => {}
 690    }
 691
 692    Ok(())
 693}
 694
 695async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 696    response.send(proto::Ack {})?;
 697    Ok(())
 698}
 699
 700async fn create_room(
 701    _request: proto::CreateRoom,
 702    response: Response<proto::CreateRoom>,
 703    session: Session,
 704) -> Result<()> {
 705    let live_kit_room = nanoid::nanoid!(30);
 706    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 707        if let Some(_) = live_kit
 708            .create_room(live_kit_room.clone())
 709            .await
 710            .trace_err()
 711        {
 712            if let Some(token) = live_kit
 713                .room_token(&live_kit_room, &session.connection_id.to_string())
 714                .trace_err()
 715            {
 716                Some(proto::LiveKitConnectionInfo {
 717                    server_url: live_kit.url().into(),
 718                    token,
 719                })
 720            } else {
 721                None
 722            }
 723        } else {
 724            None
 725        }
 726    } else {
 727        None
 728    };
 729
 730    {
 731        let room = session
 732            .db()
 733            .await
 734            .create_room(session.user_id, session.connection_id, &live_kit_room)
 735            .await?;
 736
 737        response.send(proto::CreateRoomResponse {
 738            room: Some(room.clone()),
 739            live_kit_connection_info,
 740        })?;
 741    }
 742
 743    update_user_contacts(session.user_id, &session).await?;
 744    Ok(())
 745}
 746
 747async fn join_room(
 748    request: proto::JoinRoom,
 749    response: Response<proto::JoinRoom>,
 750    session: Session,
 751) -> Result<()> {
 752    let room = {
 753        let room = session
 754            .db()
 755            .await
 756            .join_room(
 757                RoomId::from_proto(request.id),
 758                session.user_id,
 759                session.connection_id,
 760            )
 761            .await?;
 762        room_updated(&room, &session);
 763        room.clone()
 764    };
 765
 766    for connection_id in session
 767        .connection_pool()
 768        .await
 769        .user_connection_ids(session.user_id)
 770    {
 771        session
 772            .peer
 773            .send(connection_id, proto::CallCanceled {})
 774            .trace_err();
 775    }
 776
 777    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 778        if let Some(token) = live_kit
 779            .room_token(&room.live_kit_room, &session.connection_id.to_string())
 780            .trace_err()
 781        {
 782            Some(proto::LiveKitConnectionInfo {
 783                server_url: live_kit.url().into(),
 784                token,
 785            })
 786        } else {
 787            None
 788        }
 789    } else {
 790        None
 791    };
 792
 793    response.send(proto::JoinRoomResponse {
 794        room: Some(room),
 795        live_kit_connection_info,
 796    })?;
 797
 798    update_user_contacts(session.user_id, &session).await?;
 799    Ok(())
 800}
 801
 802async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
 803    leave_room_for_session(&session).await
 804}
 805
 806async fn call(
 807    request: proto::Call,
 808    response: Response<proto::Call>,
 809    session: Session,
 810) -> Result<()> {
 811    let room_id = RoomId::from_proto(request.room_id);
 812    let calling_user_id = session.user_id;
 813    let calling_connection_id = session.connection_id;
 814    let called_user_id = UserId::from_proto(request.called_user_id);
 815    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
 816    if !session
 817        .db()
 818        .await
 819        .has_contact(calling_user_id, called_user_id)
 820        .await?
 821    {
 822        return Err(anyhow!("cannot call a user who isn't a contact"))?;
 823    }
 824
 825    let incoming_call = {
 826        let (room, incoming_call) = &mut *session
 827            .db()
 828            .await
 829            .call(
 830                room_id,
 831                calling_user_id,
 832                calling_connection_id,
 833                called_user_id,
 834                initial_project_id,
 835            )
 836            .await?;
 837        room_updated(&room, &session);
 838        mem::take(incoming_call)
 839    };
 840    update_user_contacts(called_user_id, &session).await?;
 841
 842    let mut calls = session
 843        .connection_pool()
 844        .await
 845        .user_connection_ids(called_user_id)
 846        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
 847        .collect::<FuturesUnordered<_>>();
 848
 849    while let Some(call_response) = calls.next().await {
 850        match call_response.as_ref() {
 851            Ok(_) => {
 852                response.send(proto::Ack {})?;
 853                return Ok(());
 854            }
 855            Err(_) => {
 856                call_response.trace_err();
 857            }
 858        }
 859    }
 860
 861    {
 862        let room = session
 863            .db()
 864            .await
 865            .call_failed(room_id, called_user_id)
 866            .await?;
 867        room_updated(&room, &session);
 868    }
 869    update_user_contacts(called_user_id, &session).await?;
 870
 871    Err(anyhow!("failed to ring user"))?
 872}
 873
 874async fn cancel_call(
 875    request: proto::CancelCall,
 876    response: Response<proto::CancelCall>,
 877    session: Session,
 878) -> Result<()> {
 879    let called_user_id = UserId::from_proto(request.called_user_id);
 880    let room_id = RoomId::from_proto(request.room_id);
 881    {
 882        let room = session
 883            .db()
 884            .await
 885            .cancel_call(Some(room_id), session.connection_id, called_user_id)
 886            .await?;
 887        room_updated(&room, &session);
 888    }
 889
 890    for connection_id in session
 891        .connection_pool()
 892        .await
 893        .user_connection_ids(called_user_id)
 894    {
 895        session
 896            .peer
 897            .send(connection_id, proto::CallCanceled {})
 898            .trace_err();
 899    }
 900    response.send(proto::Ack {})?;
 901
 902    update_user_contacts(called_user_id, &session).await?;
 903    Ok(())
 904}
 905
 906async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
 907    let room_id = RoomId::from_proto(message.room_id);
 908    {
 909        let room = session
 910            .db()
 911            .await
 912            .decline_call(Some(room_id), session.user_id)
 913            .await?;
 914        room_updated(&room, &session);
 915    }
 916
 917    for connection_id in session
 918        .connection_pool()
 919        .await
 920        .user_connection_ids(session.user_id)
 921    {
 922        session
 923            .peer
 924            .send(connection_id, proto::CallCanceled {})
 925            .trace_err();
 926    }
 927    update_user_contacts(session.user_id, &session).await?;
 928    Ok(())
 929}
 930
 931async fn update_participant_location(
 932    request: proto::UpdateParticipantLocation,
 933    response: Response<proto::UpdateParticipantLocation>,
 934    session: Session,
 935) -> Result<()> {
 936    let room_id = RoomId::from_proto(request.room_id);
 937    let location = request
 938        .location
 939        .ok_or_else(|| anyhow!("invalid location"))?;
 940    let room = session
 941        .db()
 942        .await
 943        .update_room_participant_location(room_id, session.connection_id, location)
 944        .await?;
 945    room_updated(&room, &session);
 946    response.send(proto::Ack {})?;
 947    Ok(())
 948}
 949
 950async fn share_project(
 951    request: proto::ShareProject,
 952    response: Response<proto::ShareProject>,
 953    session: Session,
 954) -> Result<()> {
 955    let (project_id, room) = &*session
 956        .db()
 957        .await
 958        .share_project(
 959            RoomId::from_proto(request.room_id),
 960            session.connection_id,
 961            &request.worktrees,
 962        )
 963        .await?;
 964    response.send(proto::ShareProjectResponse {
 965        project_id: project_id.to_proto(),
 966    })?;
 967    room_updated(&room, &session);
 968
 969    Ok(())
 970}
 971
 972async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
 973    let project_id = ProjectId::from_proto(message.project_id);
 974
 975    let (room, guest_connection_ids) = &*session
 976        .db()
 977        .await
 978        .unshare_project(project_id, session.connection_id)
 979        .await?;
 980
 981    broadcast(
 982        session.connection_id,
 983        guest_connection_ids.iter().copied(),
 984        |conn_id| session.peer.send(conn_id, message.clone()),
 985    );
 986    room_updated(&room, &session);
 987
 988    Ok(())
 989}
 990
 991async fn join_project(
 992    request: proto::JoinProject,
 993    response: Response<proto::JoinProject>,
 994    session: Session,
 995) -> Result<()> {
 996    let project_id = ProjectId::from_proto(request.project_id);
 997    let guest_user_id = session.user_id;
 998
 999    tracing::info!(%project_id, "join project");
1000
1001    let (project, replica_id) = &mut *session
1002        .db()
1003        .await
1004        .join_project(project_id, session.connection_id)
1005        .await?;
1006
1007    let collaborators = project
1008        .collaborators
1009        .iter()
1010        .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32)
1011        .map(|collaborator| proto::Collaborator {
1012            peer_id: collaborator.connection_id as u32,
1013            replica_id: collaborator.replica_id.0 as u32,
1014            user_id: collaborator.user_id.to_proto(),
1015        })
1016        .collect::<Vec<_>>();
1017    let worktrees = project
1018        .worktrees
1019        .iter()
1020        .map(|(id, worktree)| proto::WorktreeMetadata {
1021            id: *id,
1022            root_name: worktree.root_name.clone(),
1023            visible: worktree.visible,
1024            abs_path: worktree.abs_path.clone(),
1025        })
1026        .collect::<Vec<_>>();
1027
1028    for collaborator in &collaborators {
1029        session
1030            .peer
1031            .send(
1032                ConnectionId(collaborator.peer_id),
1033                proto::AddProjectCollaborator {
1034                    project_id: project_id.to_proto(),
1035                    collaborator: Some(proto::Collaborator {
1036                        peer_id: session.connection_id.0,
1037                        replica_id: replica_id.0 as u32,
1038                        user_id: guest_user_id.to_proto(),
1039                    }),
1040                },
1041            )
1042            .trace_err();
1043    }
1044
1045    // First, we send the metadata associated with each worktree.
1046    response.send(proto::JoinProjectResponse {
1047        worktrees: worktrees.clone(),
1048        replica_id: replica_id.0 as u32,
1049        collaborators: collaborators.clone(),
1050        language_servers: project.language_servers.clone(),
1051    })?;
1052
1053    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1054        #[cfg(any(test, feature = "test-support"))]
1055        const MAX_CHUNK_SIZE: usize = 2;
1056        #[cfg(not(any(test, feature = "test-support")))]
1057        const MAX_CHUNK_SIZE: usize = 256;
1058
1059        // Stream this worktree's entries.
1060        let message = proto::UpdateWorktree {
1061            project_id: project_id.to_proto(),
1062            worktree_id,
1063            abs_path: worktree.abs_path.clone(),
1064            root_name: worktree.root_name,
1065            updated_entries: worktree.entries,
1066            removed_entries: Default::default(),
1067            scan_id: worktree.scan_id,
1068            is_last_update: worktree.is_complete,
1069        };
1070        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1071            session.peer.send(session.connection_id, update.clone())?;
1072        }
1073
1074        // Stream this worktree's diagnostics.
1075        for summary in worktree.diagnostic_summaries {
1076            session.peer.send(
1077                session.connection_id,
1078                proto::UpdateDiagnosticSummary {
1079                    project_id: project_id.to_proto(),
1080                    worktree_id: worktree.id,
1081                    summary: Some(summary),
1082                },
1083            )?;
1084        }
1085    }
1086
1087    for language_server in &project.language_servers {
1088        session.peer.send(
1089            session.connection_id,
1090            proto::UpdateLanguageServer {
1091                project_id: project_id.to_proto(),
1092                language_server_id: language_server.id,
1093                variant: Some(
1094                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1095                        proto::LspDiskBasedDiagnosticsUpdated {},
1096                    ),
1097                ),
1098            },
1099        )?;
1100    }
1101
1102    Ok(())
1103}
1104
1105async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1106    let sender_id = session.connection_id;
1107    let project_id = ProjectId::from_proto(request.project_id);
1108
1109    let project = session
1110        .db()
1111        .await
1112        .leave_project(project_id, sender_id)
1113        .await?;
1114    tracing::info!(
1115        %project_id,
1116        host_user_id = %project.host_user_id,
1117        host_connection_id = %project.host_connection_id,
1118        "leave project"
1119    );
1120    project_left(&project, &session);
1121
1122    Ok(())
1123}
1124
1125async fn update_project(
1126    request: proto::UpdateProject,
1127    response: Response<proto::UpdateProject>,
1128    session: Session,
1129) -> Result<()> {
1130    let project_id = ProjectId::from_proto(request.project_id);
1131    let (room, guest_connection_ids) = &*session
1132        .db()
1133        .await
1134        .update_project(project_id, session.connection_id, &request.worktrees)
1135        .await?;
1136    broadcast(
1137        session.connection_id,
1138        guest_connection_ids.iter().copied(),
1139        |connection_id| {
1140            session
1141                .peer
1142                .forward_send(session.connection_id, connection_id, request.clone())
1143        },
1144    );
1145    room_updated(&room, &session);
1146    response.send(proto::Ack {})?;
1147
1148    Ok(())
1149}
1150
1151async fn update_worktree(
1152    request: proto::UpdateWorktree,
1153    response: Response<proto::UpdateWorktree>,
1154    session: Session,
1155) -> Result<()> {
1156    let guest_connection_ids = session
1157        .db()
1158        .await
1159        .update_worktree(&request, session.connection_id)
1160        .await?;
1161
1162    broadcast(
1163        session.connection_id,
1164        guest_connection_ids.iter().copied(),
1165        |connection_id| {
1166            session
1167                .peer
1168                .forward_send(session.connection_id, connection_id, request.clone())
1169        },
1170    );
1171    response.send(proto::Ack {})?;
1172    Ok(())
1173}
1174
1175async fn update_diagnostic_summary(
1176    message: proto::UpdateDiagnosticSummary,
1177    session: Session,
1178) -> Result<()> {
1179    let guest_connection_ids = session
1180        .db()
1181        .await
1182        .update_diagnostic_summary(&message, session.connection_id)
1183        .await?;
1184
1185    broadcast(
1186        session.connection_id,
1187        guest_connection_ids.iter().copied(),
1188        |connection_id| {
1189            session
1190                .peer
1191                .forward_send(session.connection_id, connection_id, message.clone())
1192        },
1193    );
1194
1195    Ok(())
1196}
1197
1198async fn start_language_server(
1199    request: proto::StartLanguageServer,
1200    session: Session,
1201) -> Result<()> {
1202    let guest_connection_ids = session
1203        .db()
1204        .await
1205        .start_language_server(&request, session.connection_id)
1206        .await?;
1207
1208    broadcast(
1209        session.connection_id,
1210        guest_connection_ids.iter().copied(),
1211        |connection_id| {
1212            session
1213                .peer
1214                .forward_send(session.connection_id, connection_id, request.clone())
1215        },
1216    );
1217    Ok(())
1218}
1219
1220async fn update_language_server(
1221    request: proto::UpdateLanguageServer,
1222    session: Session,
1223) -> Result<()> {
1224    let project_id = ProjectId::from_proto(request.project_id);
1225    let project_connection_ids = session
1226        .db()
1227        .await
1228        .project_connection_ids(project_id, session.connection_id)
1229        .await?;
1230    broadcast(
1231        session.connection_id,
1232        project_connection_ids.iter().copied(),
1233        |connection_id| {
1234            session
1235                .peer
1236                .forward_send(session.connection_id, connection_id, request.clone())
1237        },
1238    );
1239    Ok(())
1240}
1241
1242async fn forward_project_request<T>(
1243    request: T,
1244    response: Response<T>,
1245    session: Session,
1246) -> Result<()>
1247where
1248    T: EntityMessage + RequestMessage,
1249{
1250    let project_id = ProjectId::from_proto(request.remote_entity_id());
1251    let host_connection_id = {
1252        let collaborators = session
1253            .db()
1254            .await
1255            .project_collaborators(project_id, session.connection_id)
1256            .await?;
1257        ConnectionId(
1258            collaborators
1259                .iter()
1260                .find(|collaborator| collaborator.is_host)
1261                .ok_or_else(|| anyhow!("host not found"))?
1262                .connection_id as u32,
1263        )
1264    };
1265
1266    let payload = session
1267        .peer
1268        .forward_request(session.connection_id, host_connection_id, request)
1269        .await?;
1270
1271    response.send(payload)?;
1272    Ok(())
1273}
1274
1275async fn save_buffer(
1276    request: proto::SaveBuffer,
1277    response: Response<proto::SaveBuffer>,
1278    session: Session,
1279) -> Result<()> {
1280    let project_id = ProjectId::from_proto(request.project_id);
1281    let host_connection_id = {
1282        let collaborators = session
1283            .db()
1284            .await
1285            .project_collaborators(project_id, session.connection_id)
1286            .await?;
1287        let host = collaborators
1288            .iter()
1289            .find(|collaborator| collaborator.is_host)
1290            .ok_or_else(|| anyhow!("host not found"))?;
1291        ConnectionId(host.connection_id as u32)
1292    };
1293    let response_payload = session
1294        .peer
1295        .forward_request(session.connection_id, host_connection_id, request.clone())
1296        .await?;
1297
1298    let mut collaborators = session
1299        .db()
1300        .await
1301        .project_collaborators(project_id, session.connection_id)
1302        .await?;
1303    collaborators
1304        .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32);
1305    let project_connection_ids = collaborators
1306        .iter()
1307        .map(|collaborator| ConnectionId(collaborator.connection_id as u32));
1308    broadcast(host_connection_id, project_connection_ids, |conn_id| {
1309        session
1310            .peer
1311            .forward_send(host_connection_id, conn_id, response_payload.clone())
1312    });
1313    response.send(response_payload)?;
1314    Ok(())
1315}
1316
1317async fn create_buffer_for_peer(
1318    request: proto::CreateBufferForPeer,
1319    session: Session,
1320) -> Result<()> {
1321    session.peer.forward_send(
1322        session.connection_id,
1323        ConnectionId(request.peer_id),
1324        request,
1325    )?;
1326    Ok(())
1327}
1328
1329async fn update_buffer(
1330    request: proto::UpdateBuffer,
1331    response: Response<proto::UpdateBuffer>,
1332    session: Session,
1333) -> Result<()> {
1334    let project_id = ProjectId::from_proto(request.project_id);
1335    let project_connection_ids = session
1336        .db()
1337        .await
1338        .project_connection_ids(project_id, session.connection_id)
1339        .await?;
1340
1341    broadcast(
1342        session.connection_id,
1343        project_connection_ids.iter().copied(),
1344        |connection_id| {
1345            session
1346                .peer
1347                .forward_send(session.connection_id, connection_id, request.clone())
1348        },
1349    );
1350    response.send(proto::Ack {})?;
1351    Ok(())
1352}
1353
1354async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1355    let project_id = ProjectId::from_proto(request.project_id);
1356    let project_connection_ids = session
1357        .db()
1358        .await
1359        .project_connection_ids(project_id, session.connection_id)
1360        .await?;
1361
1362    broadcast(
1363        session.connection_id,
1364        project_connection_ids.iter().copied(),
1365        |connection_id| {
1366            session
1367                .peer
1368                .forward_send(session.connection_id, connection_id, request.clone())
1369        },
1370    );
1371    Ok(())
1372}
1373
1374async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1375    let project_id = ProjectId::from_proto(request.project_id);
1376    let project_connection_ids = session
1377        .db()
1378        .await
1379        .project_connection_ids(project_id, session.connection_id)
1380        .await?;
1381    broadcast(
1382        session.connection_id,
1383        project_connection_ids.iter().copied(),
1384        |connection_id| {
1385            session
1386                .peer
1387                .forward_send(session.connection_id, connection_id, request.clone())
1388        },
1389    );
1390    Ok(())
1391}
1392
1393async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1394    let project_id = ProjectId::from_proto(request.project_id);
1395    let project_connection_ids = session
1396        .db()
1397        .await
1398        .project_connection_ids(project_id, session.connection_id)
1399        .await?;
1400    broadcast(
1401        session.connection_id,
1402        project_connection_ids.iter().copied(),
1403        |connection_id| {
1404            session
1405                .peer
1406                .forward_send(session.connection_id, connection_id, request.clone())
1407        },
1408    );
1409    Ok(())
1410}
1411
1412async fn follow(
1413    request: proto::Follow,
1414    response: Response<proto::Follow>,
1415    session: Session,
1416) -> Result<()> {
1417    let project_id = ProjectId::from_proto(request.project_id);
1418    let leader_id = ConnectionId(request.leader_id);
1419    let follower_id = session.connection_id;
1420    {
1421        let project_connection_ids = session
1422            .db()
1423            .await
1424            .project_connection_ids(project_id, session.connection_id)
1425            .await?;
1426
1427        if !project_connection_ids.contains(&leader_id) {
1428            Err(anyhow!("no such peer"))?;
1429        }
1430    }
1431
1432    let mut response_payload = session
1433        .peer
1434        .forward_request(session.connection_id, leader_id, request)
1435        .await?;
1436    response_payload
1437        .views
1438        .retain(|view| view.leader_id != Some(follower_id.0));
1439    response.send(response_payload)?;
1440    Ok(())
1441}
1442
1443async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1444    let project_id = ProjectId::from_proto(request.project_id);
1445    let leader_id = ConnectionId(request.leader_id);
1446    let project_connection_ids = session
1447        .db()
1448        .await
1449        .project_connection_ids(project_id, session.connection_id)
1450        .await?;
1451    if !project_connection_ids.contains(&leader_id) {
1452        Err(anyhow!("no such peer"))?;
1453    }
1454    session
1455        .peer
1456        .forward_send(session.connection_id, leader_id, request)?;
1457    Ok(())
1458}
1459
1460async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1461    let project_id = ProjectId::from_proto(request.project_id);
1462    let project_connection_ids = session
1463        .db
1464        .lock()
1465        .await
1466        .project_connection_ids(project_id, session.connection_id)
1467        .await?;
1468
1469    let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1470        proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1471        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1472        proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1473    });
1474    for follower_id in &request.follower_ids {
1475        let follower_id = ConnectionId(*follower_id);
1476        if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1477            session
1478                .peer
1479                .forward_send(session.connection_id, follower_id, request.clone())?;
1480        }
1481    }
1482    Ok(())
1483}
1484
1485async fn get_users(
1486    request: proto::GetUsers,
1487    response: Response<proto::GetUsers>,
1488    session: Session,
1489) -> Result<()> {
1490    let user_ids = request
1491        .user_ids
1492        .into_iter()
1493        .map(UserId::from_proto)
1494        .collect();
1495    let users = session
1496        .db()
1497        .await
1498        .get_users_by_ids(user_ids)
1499        .await?
1500        .into_iter()
1501        .map(|user| proto::User {
1502            id: user.id.to_proto(),
1503            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1504            github_login: user.github_login,
1505        })
1506        .collect();
1507    response.send(proto::UsersResponse { users })?;
1508    Ok(())
1509}
1510
1511async fn fuzzy_search_users(
1512    request: proto::FuzzySearchUsers,
1513    response: Response<proto::FuzzySearchUsers>,
1514    session: Session,
1515) -> Result<()> {
1516    let query = request.query;
1517    let users = match query.len() {
1518        0 => vec![],
1519        1 | 2 => session
1520            .db()
1521            .await
1522            .get_user_by_github_account(&query, None)
1523            .await?
1524            .into_iter()
1525            .collect(),
1526        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1527    };
1528    let users = users
1529        .into_iter()
1530        .filter(|user| user.id != session.user_id)
1531        .map(|user| proto::User {
1532            id: user.id.to_proto(),
1533            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1534            github_login: user.github_login,
1535        })
1536        .collect();
1537    response.send(proto::UsersResponse { users })?;
1538    Ok(())
1539}
1540
1541async fn request_contact(
1542    request: proto::RequestContact,
1543    response: Response<proto::RequestContact>,
1544    session: Session,
1545) -> Result<()> {
1546    let requester_id = session.user_id;
1547    let responder_id = UserId::from_proto(request.responder_id);
1548    if requester_id == responder_id {
1549        return Err(anyhow!("cannot add yourself as a contact"))?;
1550    }
1551
1552    session
1553        .db()
1554        .await
1555        .send_contact_request(requester_id, responder_id)
1556        .await?;
1557
1558    // Update outgoing contact requests of requester
1559    let mut update = proto::UpdateContacts::default();
1560    update.outgoing_requests.push(responder_id.to_proto());
1561    for connection_id in session
1562        .connection_pool()
1563        .await
1564        .user_connection_ids(requester_id)
1565    {
1566        session.peer.send(connection_id, update.clone())?;
1567    }
1568
1569    // Update incoming contact requests of responder
1570    let mut update = proto::UpdateContacts::default();
1571    update
1572        .incoming_requests
1573        .push(proto::IncomingContactRequest {
1574            requester_id: requester_id.to_proto(),
1575            should_notify: true,
1576        });
1577    for connection_id in session
1578        .connection_pool()
1579        .await
1580        .user_connection_ids(responder_id)
1581    {
1582        session.peer.send(connection_id, update.clone())?;
1583    }
1584
1585    response.send(proto::Ack {})?;
1586    Ok(())
1587}
1588
1589async fn respond_to_contact_request(
1590    request: proto::RespondToContactRequest,
1591    response: Response<proto::RespondToContactRequest>,
1592    session: Session,
1593) -> Result<()> {
1594    let responder_id = session.user_id;
1595    let requester_id = UserId::from_proto(request.requester_id);
1596    let db = session.db().await;
1597    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1598        db.dismiss_contact_notification(responder_id, requester_id)
1599            .await?;
1600    } else {
1601        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1602
1603        db.respond_to_contact_request(responder_id, requester_id, accept)
1604            .await?;
1605        let requester_busy = db.is_user_busy(requester_id).await?;
1606        let responder_busy = db.is_user_busy(responder_id).await?;
1607
1608        let pool = session.connection_pool().await;
1609        // Update responder with new contact
1610        let mut update = proto::UpdateContacts::default();
1611        if accept {
1612            update
1613                .contacts
1614                .push(contact_for_user(requester_id, false, requester_busy, &pool));
1615        }
1616        update
1617            .remove_incoming_requests
1618            .push(requester_id.to_proto());
1619        for connection_id in pool.user_connection_ids(responder_id) {
1620            session.peer.send(connection_id, update.clone())?;
1621        }
1622
1623        // Update requester with new contact
1624        let mut update = proto::UpdateContacts::default();
1625        if accept {
1626            update
1627                .contacts
1628                .push(contact_for_user(responder_id, true, responder_busy, &pool));
1629        }
1630        update
1631            .remove_outgoing_requests
1632            .push(responder_id.to_proto());
1633        for connection_id in pool.user_connection_ids(requester_id) {
1634            session.peer.send(connection_id, update.clone())?;
1635        }
1636    }
1637
1638    response.send(proto::Ack {})?;
1639    Ok(())
1640}
1641
1642async fn remove_contact(
1643    request: proto::RemoveContact,
1644    response: Response<proto::RemoveContact>,
1645    session: Session,
1646) -> Result<()> {
1647    let requester_id = session.user_id;
1648    let responder_id = UserId::from_proto(request.user_id);
1649    let db = session.db().await;
1650    db.remove_contact(requester_id, responder_id).await?;
1651
1652    let pool = session.connection_pool().await;
1653    // Update outgoing contact requests of requester
1654    let mut update = proto::UpdateContacts::default();
1655    update
1656        .remove_outgoing_requests
1657        .push(responder_id.to_proto());
1658    for connection_id in pool.user_connection_ids(requester_id) {
1659        session.peer.send(connection_id, update.clone())?;
1660    }
1661
1662    // Update incoming contact requests of responder
1663    let mut update = proto::UpdateContacts::default();
1664    update
1665        .remove_incoming_requests
1666        .push(requester_id.to_proto());
1667    for connection_id in pool.user_connection_ids(responder_id) {
1668        session.peer.send(connection_id, update.clone())?;
1669    }
1670
1671    response.send(proto::Ack {})?;
1672    Ok(())
1673}
1674
1675async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
1676    let project_id = ProjectId::from_proto(request.project_id);
1677    let project_connection_ids = session
1678        .db()
1679        .await
1680        .project_connection_ids(project_id, session.connection_id)
1681        .await?;
1682    broadcast(
1683        session.connection_id,
1684        project_connection_ids.iter().copied(),
1685        |connection_id| {
1686            session
1687                .peer
1688                .forward_send(session.connection_id, connection_id, request.clone())
1689        },
1690    );
1691    Ok(())
1692}
1693
1694async fn get_private_user_info(
1695    _request: proto::GetPrivateUserInfo,
1696    response: Response<proto::GetPrivateUserInfo>,
1697    session: Session,
1698) -> Result<()> {
1699    let metrics_id = session
1700        .db()
1701        .await
1702        .get_user_metrics_id(session.user_id)
1703        .await?;
1704    let user = session
1705        .db()
1706        .await
1707        .get_user_by_id(session.user_id)
1708        .await?
1709        .ok_or_else(|| anyhow!("user not found"))?;
1710    response.send(proto::GetPrivateUserInfoResponse {
1711        metrics_id,
1712        staff: user.admin,
1713    })?;
1714    Ok(())
1715}
1716
1717fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1718    match message {
1719        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1720        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1721        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1722        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1723        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1724            code: frame.code.into(),
1725            reason: frame.reason,
1726        })),
1727    }
1728}
1729
1730fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1731    match message {
1732        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1733        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1734        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1735        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1736        AxumMessage::Close(frame) => {
1737            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1738                code: frame.code.into(),
1739                reason: frame.reason,
1740            }))
1741        }
1742    }
1743}
1744
1745fn build_initial_contacts_update(
1746    contacts: Vec<db::Contact>,
1747    pool: &ConnectionPool,
1748) -> proto::UpdateContacts {
1749    let mut update = proto::UpdateContacts::default();
1750
1751    for contact in contacts {
1752        match contact {
1753            db::Contact::Accepted {
1754                user_id,
1755                should_notify,
1756                busy,
1757            } => {
1758                update
1759                    .contacts
1760                    .push(contact_for_user(user_id, should_notify, busy, &pool));
1761            }
1762            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
1763            db::Contact::Incoming {
1764                user_id,
1765                should_notify,
1766            } => update
1767                .incoming_requests
1768                .push(proto::IncomingContactRequest {
1769                    requester_id: user_id.to_proto(),
1770                    should_notify,
1771                }),
1772        }
1773    }
1774
1775    update
1776}
1777
1778fn contact_for_user(
1779    user_id: UserId,
1780    should_notify: bool,
1781    busy: bool,
1782    pool: &ConnectionPool,
1783) -> proto::Contact {
1784    proto::Contact {
1785        user_id: user_id.to_proto(),
1786        online: pool.is_user_online(user_id),
1787        busy,
1788        should_notify,
1789    }
1790}
1791
1792fn room_updated(room: &proto::Room, session: &Session) {
1793    for participant in &room.participants {
1794        session
1795            .peer
1796            .send(
1797                ConnectionId(participant.peer_id),
1798                proto::RoomUpdated {
1799                    room: Some(room.clone()),
1800                },
1801            )
1802            .trace_err();
1803    }
1804}
1805
1806async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
1807    let db = session.db().await;
1808    let contacts = db.get_contacts(user_id).await?;
1809    let busy = db.is_user_busy(user_id).await?;
1810
1811    let pool = session.connection_pool().await;
1812    let updated_contact = contact_for_user(user_id, false, busy, &pool);
1813    for contact in contacts {
1814        if let db::Contact::Accepted {
1815            user_id: contact_user_id,
1816            ..
1817        } = contact
1818        {
1819            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
1820                session
1821                    .peer
1822                    .send(
1823                        contact_conn_id,
1824                        proto::UpdateContacts {
1825                            contacts: vec![updated_contact.clone()],
1826                            remove_contacts: Default::default(),
1827                            incoming_requests: Default::default(),
1828                            remove_incoming_requests: Default::default(),
1829                            outgoing_requests: Default::default(),
1830                            remove_outgoing_requests: Default::default(),
1831                        },
1832                    )
1833                    .trace_err();
1834            }
1835        }
1836    }
1837    Ok(())
1838}
1839
1840async fn leave_room_for_session(session: &Session) -> Result<()> {
1841    let mut contacts_to_update = HashSet::default();
1842
1843    let canceled_calls_to_user_ids;
1844    let live_kit_room;
1845    let delete_live_kit_room;
1846    {
1847        let mut left_room = session.db().await.leave_room(session.connection_id).await?;
1848        contacts_to_update.insert(session.user_id);
1849
1850        for project in left_room.left_projects.values() {
1851            project_left(project, session);
1852        }
1853
1854        room_updated(&left_room.room, &session);
1855        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
1856        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
1857        delete_live_kit_room = left_room.room.participants.is_empty();
1858    }
1859
1860    {
1861        let pool = session.connection_pool().await;
1862        for canceled_user_id in canceled_calls_to_user_ids {
1863            for connection_id in pool.user_connection_ids(canceled_user_id) {
1864                session
1865                    .peer
1866                    .send(connection_id, proto::CallCanceled {})
1867                    .trace_err();
1868            }
1869            contacts_to_update.insert(canceled_user_id);
1870        }
1871    }
1872
1873    for contact_user_id in contacts_to_update {
1874        update_user_contacts(contact_user_id, &session).await?;
1875    }
1876
1877    if let Some(live_kit) = session.live_kit_client.as_ref() {
1878        live_kit
1879            .remove_participant(live_kit_room.clone(), session.connection_id.to_string())
1880            .await
1881            .trace_err();
1882
1883        if delete_live_kit_room {
1884            live_kit.delete_room(live_kit_room).await.trace_err();
1885        }
1886    }
1887
1888    Ok(())
1889}
1890
1891fn project_left(project: &db::LeftProject, session: &Session) {
1892    for connection_id in &project.connection_ids {
1893        if project.host_user_id == session.user_id {
1894            session
1895                .peer
1896                .send(
1897                    *connection_id,
1898                    proto::UnshareProject {
1899                        project_id: project.id.to_proto(),
1900                    },
1901                )
1902                .trace_err();
1903        } else {
1904            session
1905                .peer
1906                .send(
1907                    *connection_id,
1908                    proto::RemoveProjectCollaborator {
1909                        project_id: project.id.to_proto(),
1910                        peer_id: session.connection_id.0,
1911                    },
1912                )
1913                .trace_err();
1914        }
1915    }
1916
1917    session
1918        .peer
1919        .send(
1920            session.connection_id,
1921            proto::UnshareProject {
1922                project_id: project.id.to_proto(),
1923            },
1924        )
1925        .trace_err();
1926}
1927
1928pub trait ResultExt {
1929    type Ok;
1930
1931    fn trace_err(self) -> Option<Self::Ok>;
1932}
1933
1934impl<T, E> ResultExt for Result<T, E>
1935where
1936    E: std::fmt::Debug,
1937{
1938    type Ok = T;
1939
1940    fn trace_err(self) -> Option<T> {
1941        match self {
1942            Ok(value) => Some(value),
1943            Err(error) => {
1944                tracing::error!("{:?}", error);
1945                None
1946            }
1947        }
1948    }
1949}