rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{self, DefaultDb, ProjectId, RoomId, User, UserId},
   6    AppState, Result,
   7};
   8use anyhow::anyhow;
   9use async_tungstenite::tungstenite::{
  10    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  11};
  12use axum::{
  13    body::Body,
  14    extract::{
  15        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  16        ConnectInfo, WebSocketUpgrade,
  17    },
  18    headers::{Header, HeaderName},
  19    http::StatusCode,
  20    middleware,
  21    response::IntoResponse,
  22    routing::get,
  23    Extension, Router, TypedHeader,
  24};
  25use collections::{HashMap, HashSet};
  26pub use connection_pool::ConnectionPool;
  27use futures::{
  28    channel::oneshot,
  29    future::{self, BoxFuture},
  30    stream::FuturesUnordered,
  31    FutureExt, SinkExt, StreamExt, TryStreamExt,
  32};
  33use lazy_static::lazy_static;
  34use prometheus::{register_int_gauge, IntGauge};
  35use rpc::{
  36    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  37    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  38};
  39use serde::{Serialize, Serializer};
  40use std::{
  41    any::TypeId,
  42    fmt,
  43    future::Future,
  44    marker::PhantomData,
  45    net::SocketAddr,
  46    ops::{Deref, DerefMut},
  47    rc::Rc,
  48    sync::{
  49        atomic::{AtomicBool, Ordering::SeqCst},
  50        Arc,
  51    },
  52    time::Duration,
  53};
  54use tokio::{
  55    sync::{Mutex, MutexGuard},
  56    time::Sleep,
  57};
  58use tower::ServiceBuilder;
  59use tracing::{info_span, instrument, Instrument};
  60
  61lazy_static! {
  62    static ref METRIC_CONNECTIONS: IntGauge =
  63        register_int_gauge!("connections", "number of connections").unwrap();
  64    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
  65        "shared_projects",
  66        "number of open projects with one or more guests"
  67    )
  68    .unwrap();
  69}
  70
  71type MessageHandler =
  72    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  73
  74struct Response<R> {
  75    peer: Arc<Peer>,
  76    receipt: Receipt<R>,
  77    responded: Arc<AtomicBool>,
  78}
  79
  80impl<R: RequestMessage> Response<R> {
  81    fn send(self, payload: R::Response) -> Result<()> {
  82        self.responded.store(true, SeqCst);
  83        self.peer.respond(self.receipt, payload)?;
  84        Ok(())
  85    }
  86}
  87
  88#[derive(Clone)]
  89struct Session {
  90    user_id: UserId,
  91    connection_id: ConnectionId,
  92    db: Arc<Mutex<DbHandle>>,
  93    peer: Arc<Peer>,
  94    connection_pool: Arc<Mutex<ConnectionPool>>,
  95    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
  96}
  97
  98impl Session {
  99    async fn db(&self) -> MutexGuard<DbHandle> {
 100        #[cfg(test)]
 101        tokio::task::yield_now().await;
 102        let guard = self.db.lock().await;
 103        #[cfg(test)]
 104        tokio::task::yield_now().await;
 105        guard
 106    }
 107
 108    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 109        #[cfg(test)]
 110        tokio::task::yield_now().await;
 111        let guard = self.connection_pool.lock().await;
 112        #[cfg(test)]
 113        tokio::task::yield_now().await;
 114        ConnectionPoolGuard {
 115            guard,
 116            _not_send: PhantomData,
 117        }
 118    }
 119}
 120
 121impl fmt::Debug for Session {
 122    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 123        f.debug_struct("Session")
 124            .field("user_id", &self.user_id)
 125            .field("connection_id", &self.connection_id)
 126            .finish()
 127    }
 128}
 129
 130struct DbHandle(Arc<DefaultDb>);
 131
 132impl Deref for DbHandle {
 133    type Target = DefaultDb;
 134
 135    fn deref(&self) -> &Self::Target {
 136        self.0.as_ref()
 137    }
 138}
 139
 140pub struct Server {
 141    peer: Arc<Peer>,
 142    pub(crate) connection_pool: Arc<Mutex<ConnectionPool>>,
 143    app_state: Arc<AppState>,
 144    handlers: HashMap<TypeId, MessageHandler>,
 145}
 146
 147pub trait Executor: Send + Clone {
 148    type Sleep: Send + Future;
 149    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
 150    fn sleep(&self, duration: Duration) -> Self::Sleep;
 151}
 152
 153#[derive(Clone)]
 154pub struct RealExecutor;
 155
 156pub(crate) struct ConnectionPoolGuard<'a> {
 157    guard: MutexGuard<'a, ConnectionPool>,
 158    _not_send: PhantomData<Rc<()>>,
 159}
 160
 161#[derive(Serialize)]
 162pub struct ServerSnapshot<'a> {
 163    peer: &'a Peer,
 164    #[serde(serialize_with = "serialize_deref")]
 165    connection_pool: ConnectionPoolGuard<'a>,
 166}
 167
 168pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 169where
 170    S: Serializer,
 171    T: Deref<Target = U>,
 172    U: Serialize,
 173{
 174    Serialize::serialize(value.deref(), serializer)
 175}
 176
 177impl Server {
 178    pub fn new(app_state: Arc<AppState>) -> Arc<Self> {
 179        let mut server = Self {
 180            peer: Peer::new(),
 181            app_state,
 182            connection_pool: Default::default(),
 183            handlers: Default::default(),
 184        };
 185
 186        server
 187            .add_request_handler(ping)
 188            .add_request_handler(create_room)
 189            .add_request_handler(join_room)
 190            .add_message_handler(leave_room)
 191            .add_request_handler(call)
 192            .add_request_handler(cancel_call)
 193            .add_message_handler(decline_call)
 194            .add_request_handler(update_participant_location)
 195            .add_request_handler(share_project)
 196            .add_message_handler(unshare_project)
 197            .add_request_handler(join_project)
 198            .add_message_handler(leave_project)
 199            .add_request_handler(update_project)
 200            .add_request_handler(update_worktree)
 201            .add_message_handler(start_language_server)
 202            .add_message_handler(update_language_server)
 203            .add_request_handler(update_diagnostic_summary)
 204            .add_request_handler(forward_project_request::<proto::GetHover>)
 205            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 206            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 207            .add_request_handler(forward_project_request::<proto::GetReferences>)
 208            .add_request_handler(forward_project_request::<proto::SearchProject>)
 209            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 210            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 211            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 212            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 213            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 214            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 215            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 216            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 217            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 218            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 219            .add_request_handler(forward_project_request::<proto::PerformRename>)
 220            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 221            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 222            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 223            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 224            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 225            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 226            .add_message_handler(create_buffer_for_peer)
 227            .add_request_handler(update_buffer)
 228            .add_message_handler(update_buffer_file)
 229            .add_message_handler(buffer_reloaded)
 230            .add_message_handler(buffer_saved)
 231            .add_request_handler(save_buffer)
 232            .add_request_handler(get_users)
 233            .add_request_handler(fuzzy_search_users)
 234            .add_request_handler(request_contact)
 235            .add_request_handler(remove_contact)
 236            .add_request_handler(respond_to_contact_request)
 237            .add_request_handler(follow)
 238            .add_message_handler(unfollow)
 239            .add_message_handler(update_followers)
 240            .add_message_handler(update_diff_base)
 241            .add_request_handler(get_private_user_info);
 242
 243        Arc::new(server)
 244    }
 245
 246    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 247    where
 248        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 249        Fut: 'static + Send + Future<Output = Result<()>>,
 250        M: EnvelopedMessage,
 251    {
 252        let prev_handler = self.handlers.insert(
 253            TypeId::of::<M>(),
 254            Box::new(move |envelope, session| {
 255                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 256                let span = info_span!(
 257                    "handle message",
 258                    payload_type = envelope.payload_type_name()
 259                );
 260                span.in_scope(|| {
 261                    tracing::info!(
 262                        payload_type = envelope.payload_type_name(),
 263                        "message received"
 264                    );
 265                });
 266                let future = (handler)(*envelope, session);
 267                async move {
 268                    if let Err(error) = future.await {
 269                        tracing::error!(%error, "error handling message");
 270                    }
 271                }
 272                .instrument(span)
 273                .boxed()
 274            }),
 275        );
 276        if prev_handler.is_some() {
 277            panic!("registered a handler for the same message twice");
 278        }
 279        self
 280    }
 281
 282    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 283    where
 284        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 285        Fut: 'static + Send + Future<Output = Result<()>>,
 286        M: EnvelopedMessage,
 287    {
 288        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 289        self
 290    }
 291
 292    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 293    where
 294        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 295        Fut: Send + Future<Output = Result<()>>,
 296        M: RequestMessage,
 297    {
 298        let handler = Arc::new(handler);
 299        self.add_handler(move |envelope, session| {
 300            let receipt = envelope.receipt();
 301            let handler = handler.clone();
 302            async move {
 303                let peer = session.peer.clone();
 304                let responded = Arc::new(AtomicBool::default());
 305                let response = Response {
 306                    peer: peer.clone(),
 307                    responded: responded.clone(),
 308                    receipt,
 309                };
 310                match (handler)(envelope.payload, response, session).await {
 311                    Ok(()) => {
 312                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 313                            Ok(())
 314                        } else {
 315                            Err(anyhow!("handler did not send a response"))?
 316                        }
 317                    }
 318                    Err(error) => {
 319                        peer.respond_with_error(
 320                            receipt,
 321                            proto::Error {
 322                                message: error.to_string(),
 323                            },
 324                        )?;
 325                        Err(error)
 326                    }
 327                }
 328            }
 329        })
 330    }
 331
 332    pub fn handle_connection<E: Executor>(
 333        self: &Arc<Self>,
 334        connection: Connection,
 335        address: String,
 336        user: User,
 337        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 338        executor: E,
 339    ) -> impl Future<Output = Result<()>> {
 340        let this = self.clone();
 341        let user_id = user.id;
 342        let login = user.github_login;
 343        let span = info_span!("handle connection", %user_id, %login, %address);
 344        async move {
 345            let (connection_id, handle_io, mut incoming_rx) = this
 346                .peer
 347                .add_connection(connection, {
 348                    let executor = executor.clone();
 349                    move |duration| {
 350                        let timer = executor.sleep(duration);
 351                        async move {
 352                            timer.await;
 353                        }
 354                    }
 355                });
 356
 357            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
 358            this.peer.send(connection_id, proto::Hello { peer_id: connection_id.0 })?;
 359            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
 360
 361            if let Some(send_connection_id) = send_connection_id.take() {
 362                let _ = send_connection_id.send(connection_id);
 363            }
 364
 365            if !user.connected_once {
 366                this.peer.send(connection_id, proto::ShowContacts {})?;
 367                this.app_state.db.set_user_connected_once(user_id, true).await?;
 368            }
 369
 370            let (contacts, invite_code) = future::try_join(
 371                this.app_state.db.get_contacts(user_id),
 372                this.app_state.db.get_invite_code_for_user(user_id)
 373            ).await?;
 374
 375            {
 376                let mut pool = this.connection_pool.lock().await;
 377                pool.add_connection(connection_id, user_id, user.admin);
 378                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
 379
 380                if let Some((code, count)) = invite_code {
 381                    this.peer.send(connection_id, proto::UpdateInviteInfo {
 382                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
 383                        count,
 384                    })?;
 385                }
 386            }
 387
 388            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
 389                this.peer.send(connection_id, incoming_call)?;
 390            }
 391
 392            let session = Session {
 393                user_id,
 394                connection_id,
 395                db: Arc::new(Mutex::new(DbHandle(this.app_state.db.clone()))),
 396                peer: this.peer.clone(),
 397                connection_pool: this.connection_pool.clone(),
 398                live_kit_client: this.app_state.live_kit_client.clone()
 399            };
 400            update_user_contacts(user_id, &session).await?;
 401
 402            let handle_io = handle_io.fuse();
 403            futures::pin_mut!(handle_io);
 404
 405            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 406            // This prevents deadlocks when e.g., client A performs a request to client B and
 407            // client B performs a request to client A. If both clients stop processing further
 408            // messages until their respective request completes, they won't have a chance to
 409            // respond to the other client's request and cause a deadlock.
 410            //
 411            // This arrangement ensures we will attempt to process earlier messages first, but fall
 412            // back to processing messages arrived later in the spirit of making progress.
 413            let mut foreground_message_handlers = FuturesUnordered::new();
 414            loop {
 415                let next_message = incoming_rx.next().fuse();
 416                futures::pin_mut!(next_message);
 417                futures::select_biased! {
 418                    result = handle_io => {
 419                        if let Err(error) = result {
 420                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
 421                        }
 422                        break;
 423                    }
 424                    _ = foreground_message_handlers.next() => {}
 425                    message = next_message => {
 426                        if let Some(message) = message {
 427                            let type_name = message.payload_type_name();
 428                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
 429                            let span_enter = span.enter();
 430                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 431                                let is_background = message.is_background();
 432                                let handle_message = (handler)(message, session.clone());
 433                                drop(span_enter);
 434
 435                                let handle_message = handle_message.instrument(span);
 436                                if is_background {
 437                                    executor.spawn_detached(handle_message);
 438                                } else {
 439                                    foreground_message_handlers.push(handle_message);
 440                                }
 441                            } else {
 442                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
 443                            }
 444                        } else {
 445                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
 446                            break;
 447                        }
 448                    }
 449                }
 450            }
 451
 452            drop(foreground_message_handlers);
 453            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
 454            if let Err(error) = sign_out(session).await {
 455                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
 456            }
 457
 458            Ok(())
 459        }.instrument(span)
 460    }
 461
 462    pub async fn invite_code_redeemed(
 463        self: &Arc<Self>,
 464        inviter_id: UserId,
 465        invitee_id: UserId,
 466    ) -> Result<()> {
 467        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 468            if let Some(code) = &user.invite_code {
 469                let pool = self.connection_pool.lock().await;
 470                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 471                for connection_id in pool.user_connection_ids(inviter_id) {
 472                    self.peer.send(
 473                        connection_id,
 474                        proto::UpdateContacts {
 475                            contacts: vec![invitee_contact.clone()],
 476                            ..Default::default()
 477                        },
 478                    )?;
 479                    self.peer.send(
 480                        connection_id,
 481                        proto::UpdateInviteInfo {
 482                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 483                            count: user.invite_count as u32,
 484                        },
 485                    )?;
 486                }
 487            }
 488        }
 489        Ok(())
 490    }
 491
 492    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 493        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 494            if let Some(invite_code) = &user.invite_code {
 495                let pool = self.connection_pool.lock().await;
 496                for connection_id in pool.user_connection_ids(user_id) {
 497                    self.peer.send(
 498                        connection_id,
 499                        proto::UpdateInviteInfo {
 500                            url: format!(
 501                                "{}{}",
 502                                self.app_state.config.invite_link_prefix, invite_code
 503                            ),
 504                            count: user.invite_count as u32,
 505                        },
 506                    )?;
 507                }
 508            }
 509        }
 510        Ok(())
 511    }
 512
 513    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 514        ServerSnapshot {
 515            connection_pool: ConnectionPoolGuard {
 516                guard: self.connection_pool.lock().await,
 517                _not_send: PhantomData,
 518            },
 519            peer: &self.peer,
 520        }
 521    }
 522}
 523
 524impl<'a> Deref for ConnectionPoolGuard<'a> {
 525    type Target = ConnectionPool;
 526
 527    fn deref(&self) -> &Self::Target {
 528        &*self.guard
 529    }
 530}
 531
 532impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 533    fn deref_mut(&mut self) -> &mut Self::Target {
 534        &mut *self.guard
 535    }
 536}
 537
 538impl<'a> Drop for ConnectionPoolGuard<'a> {
 539    fn drop(&mut self) {
 540        #[cfg(test)]
 541        self.check_invariants();
 542    }
 543}
 544
 545impl Executor for RealExecutor {
 546    type Sleep = Sleep;
 547
 548    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
 549        tokio::task::spawn(future);
 550    }
 551
 552    fn sleep(&self, duration: Duration) -> Self::Sleep {
 553        tokio::time::sleep(duration)
 554    }
 555}
 556
 557fn broadcast<F>(
 558    sender_id: ConnectionId,
 559    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 560    mut f: F,
 561) where
 562    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 563{
 564    for receiver_id in receiver_ids {
 565        if receiver_id != sender_id {
 566            f(receiver_id).trace_err();
 567        }
 568    }
 569}
 570
 571lazy_static! {
 572    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
 573}
 574
 575pub struct ProtocolVersion(u32);
 576
 577impl Header for ProtocolVersion {
 578    fn name() -> &'static HeaderName {
 579        &ZED_PROTOCOL_VERSION
 580    }
 581
 582    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 583    where
 584        Self: Sized,
 585        I: Iterator<Item = &'i axum::http::HeaderValue>,
 586    {
 587        let version = values
 588            .next()
 589            .ok_or_else(axum::headers::Error::invalid)?
 590            .to_str()
 591            .map_err(|_| axum::headers::Error::invalid())?
 592            .parse()
 593            .map_err(|_| axum::headers::Error::invalid())?;
 594        Ok(Self(version))
 595    }
 596
 597    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 598        values.extend([self.0.to_string().parse().unwrap()]);
 599    }
 600}
 601
 602pub fn routes(server: Arc<Server>) -> Router<Body> {
 603    Router::new()
 604        .route("/rpc", get(handle_websocket_request))
 605        .layer(
 606            ServiceBuilder::new()
 607                .layer(Extension(server.app_state.clone()))
 608                .layer(middleware::from_fn(auth::validate_header)),
 609        )
 610        .route("/metrics", get(handle_metrics))
 611        .layer(Extension(server))
 612}
 613
 614pub async fn handle_websocket_request(
 615    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 616    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 617    Extension(server): Extension<Arc<Server>>,
 618    Extension(user): Extension<User>,
 619    ws: WebSocketUpgrade,
 620) -> axum::response::Response {
 621    if protocol_version != rpc::PROTOCOL_VERSION {
 622        return (
 623            StatusCode::UPGRADE_REQUIRED,
 624            "client must be upgraded".to_string(),
 625        )
 626            .into_response();
 627    }
 628    let socket_address = socket_address.to_string();
 629    ws.on_upgrade(move |socket| {
 630        use util::ResultExt;
 631        let socket = socket
 632            .map_ok(to_tungstenite_message)
 633            .err_into()
 634            .with(|message| async move { Ok(to_axum_message(message)) });
 635        let connection = Connection::new(Box::pin(socket));
 636        async move {
 637            server
 638                .handle_connection(connection, socket_address, user, None, RealExecutor)
 639                .await
 640                .log_err();
 641        }
 642    })
 643}
 644
 645pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 646    let connections = server
 647        .connection_pool
 648        .lock()
 649        .await
 650        .connections()
 651        .filter(|connection| !connection.admin)
 652        .count();
 653
 654    METRIC_CONNECTIONS.set(connections as _);
 655
 656    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 657    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 658
 659    let encoder = prometheus::TextEncoder::new();
 660    let metric_families = prometheus::gather();
 661    let encoded_metrics = encoder
 662        .encode_to_string(&metric_families)
 663        .map_err(|err| anyhow!("{}", err))?;
 664    Ok(encoded_metrics)
 665}
 666
 667#[instrument(err)]
 668async fn sign_out(session: Session) -> Result<()> {
 669    session.peer.disconnect(session.connection_id);
 670    let decline_calls = {
 671        let mut pool = session.connection_pool().await;
 672        pool.remove_connection(session.connection_id)?;
 673        let mut connections = pool.user_connection_ids(session.user_id);
 674        connections.next().is_none()
 675    };
 676
 677    leave_room_for_session(&session).await.trace_err();
 678    if decline_calls {
 679        if let Some(room) = session
 680            .db()
 681            .await
 682            .decline_call(None, session.user_id)
 683            .await
 684            .trace_err()
 685        {
 686            room_updated(&room, &session);
 687        }
 688    }
 689
 690    update_user_contacts(session.user_id, &session).await?;
 691
 692    Ok(())
 693}
 694
 695async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 696    response.send(proto::Ack {})?;
 697    Ok(())
 698}
 699
 700async fn create_room(
 701    _request: proto::CreateRoom,
 702    response: Response<proto::CreateRoom>,
 703    session: Session,
 704) -> Result<()> {
 705    let room = session
 706        .db()
 707        .await
 708        .create_room(session.user_id, session.connection_id)
 709        .await?;
 710
 711    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 712        if let Some(_) = live_kit
 713            .create_room(room.live_kit_room.clone())
 714            .await
 715            .trace_err()
 716        {
 717            if let Some(token) = live_kit
 718                .room_token(&room.live_kit_room, &session.connection_id.to_string())
 719                .trace_err()
 720            {
 721                Some(proto::LiveKitConnectionInfo {
 722                    server_url: live_kit.url().into(),
 723                    token,
 724                })
 725            } else {
 726                None
 727            }
 728        } else {
 729            None
 730        }
 731    } else {
 732        None
 733    };
 734
 735    response.send(proto::CreateRoomResponse {
 736        room: Some(room),
 737        live_kit_connection_info,
 738    })?;
 739    update_user_contacts(session.user_id, &session).await?;
 740    Ok(())
 741}
 742
 743async fn join_room(
 744    request: proto::JoinRoom,
 745    response: Response<proto::JoinRoom>,
 746    session: Session,
 747) -> Result<()> {
 748    let room = session
 749        .db()
 750        .await
 751        .join_room(
 752            RoomId::from_proto(request.id),
 753            session.user_id,
 754            session.connection_id,
 755        )
 756        .await?;
 757    for connection_id in session
 758        .connection_pool()
 759        .await
 760        .user_connection_ids(session.user_id)
 761    {
 762        session
 763            .peer
 764            .send(connection_id, proto::CallCanceled {})
 765            .trace_err();
 766    }
 767
 768    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 769        if let Some(token) = live_kit
 770            .room_token(&room.live_kit_room, &session.connection_id.to_string())
 771            .trace_err()
 772        {
 773            Some(proto::LiveKitConnectionInfo {
 774                server_url: live_kit.url().into(),
 775                token,
 776            })
 777        } else {
 778            None
 779        }
 780    } else {
 781        None
 782    };
 783
 784    room_updated(&room, &session);
 785    response.send(proto::JoinRoomResponse {
 786        room: Some(room),
 787        live_kit_connection_info,
 788    })?;
 789
 790    update_user_contacts(session.user_id, &session).await?;
 791    Ok(())
 792}
 793
 794async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
 795    leave_room_for_session(&session).await
 796}
 797
 798async fn call(
 799    request: proto::Call,
 800    response: Response<proto::Call>,
 801    session: Session,
 802) -> Result<()> {
 803    let room_id = RoomId::from_proto(request.room_id);
 804    let calling_user_id = session.user_id;
 805    let calling_connection_id = session.connection_id;
 806    let called_user_id = UserId::from_proto(request.called_user_id);
 807    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
 808    if !session
 809        .db()
 810        .await
 811        .has_contact(calling_user_id, called_user_id)
 812        .await?
 813    {
 814        return Err(anyhow!("cannot call a user who isn't a contact"))?;
 815    }
 816
 817    let (room, incoming_call) = session
 818        .db()
 819        .await
 820        .call(
 821            room_id,
 822            calling_user_id,
 823            calling_connection_id,
 824            called_user_id,
 825            initial_project_id,
 826        )
 827        .await?;
 828    room_updated(&room, &session);
 829    update_user_contacts(called_user_id, &session).await?;
 830
 831    let mut calls = session
 832        .connection_pool()
 833        .await
 834        .user_connection_ids(called_user_id)
 835        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
 836        .collect::<FuturesUnordered<_>>();
 837
 838    while let Some(call_response) = calls.next().await {
 839        match call_response.as_ref() {
 840            Ok(_) => {
 841                response.send(proto::Ack {})?;
 842                return Ok(());
 843            }
 844            Err(_) => {
 845                call_response.trace_err();
 846            }
 847        }
 848    }
 849
 850    let room = session
 851        .db()
 852        .await
 853        .call_failed(room_id, called_user_id)
 854        .await?;
 855    room_updated(&room, &session);
 856    update_user_contacts(called_user_id, &session).await?;
 857
 858    Err(anyhow!("failed to ring user"))?
 859}
 860
 861async fn cancel_call(
 862    request: proto::CancelCall,
 863    response: Response<proto::CancelCall>,
 864    session: Session,
 865) -> Result<()> {
 866    let called_user_id = UserId::from_proto(request.called_user_id);
 867    let room_id = RoomId::from_proto(request.room_id);
 868    let room = session
 869        .db()
 870        .await
 871        .cancel_call(Some(room_id), session.connection_id, called_user_id)
 872        .await?;
 873    for connection_id in session
 874        .connection_pool()
 875        .await
 876        .user_connection_ids(called_user_id)
 877    {
 878        session
 879            .peer
 880            .send(connection_id, proto::CallCanceled {})
 881            .trace_err();
 882    }
 883    room_updated(&room, &session);
 884    response.send(proto::Ack {})?;
 885
 886    update_user_contacts(called_user_id, &session).await?;
 887    Ok(())
 888}
 889
 890async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
 891    let room_id = RoomId::from_proto(message.room_id);
 892    let room = session
 893        .db()
 894        .await
 895        .decline_call(Some(room_id), session.user_id)
 896        .await?;
 897    for connection_id in session
 898        .connection_pool()
 899        .await
 900        .user_connection_ids(session.user_id)
 901    {
 902        session
 903            .peer
 904            .send(connection_id, proto::CallCanceled {})
 905            .trace_err();
 906    }
 907    room_updated(&room, &session);
 908    update_user_contacts(session.user_id, &session).await?;
 909    Ok(())
 910}
 911
 912async fn update_participant_location(
 913    request: proto::UpdateParticipantLocation,
 914    response: Response<proto::UpdateParticipantLocation>,
 915    session: Session,
 916) -> Result<()> {
 917    let room_id = RoomId::from_proto(request.room_id);
 918    let location = request
 919        .location
 920        .ok_or_else(|| anyhow!("invalid location"))?;
 921    let room = session
 922        .db()
 923        .await
 924        .update_room_participant_location(room_id, session.connection_id, location)
 925        .await?;
 926    room_updated(&room, &session);
 927    response.send(proto::Ack {})?;
 928    Ok(())
 929}
 930
 931async fn share_project(
 932    request: proto::ShareProject,
 933    response: Response<proto::ShareProject>,
 934    session: Session,
 935) -> Result<()> {
 936    let (project_id, room) = session
 937        .db()
 938        .await
 939        .share_project(
 940            RoomId::from_proto(request.room_id),
 941            session.connection_id,
 942            &request.worktrees,
 943        )
 944        .await?;
 945    response.send(proto::ShareProjectResponse {
 946        project_id: project_id.to_proto(),
 947    })?;
 948    room_updated(&room, &session);
 949
 950    Ok(())
 951}
 952
 953async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
 954    let project_id = ProjectId::from_proto(message.project_id);
 955
 956    let (room, guest_connection_ids) = session
 957        .db()
 958        .await
 959        .unshare_project(project_id, session.connection_id)
 960        .await?;
 961
 962    broadcast(session.connection_id, guest_connection_ids, |conn_id| {
 963        session.peer.send(conn_id, message.clone())
 964    });
 965    room_updated(&room, &session);
 966
 967    Ok(())
 968}
 969
 970async fn join_project(
 971    request: proto::JoinProject,
 972    response: Response<proto::JoinProject>,
 973    session: Session,
 974) -> Result<()> {
 975    let project_id = ProjectId::from_proto(request.project_id);
 976    let guest_user_id = session.user_id;
 977
 978    tracing::info!(%project_id, "join project");
 979
 980    let (project, replica_id) = session
 981        .db()
 982        .await
 983        .join_project(project_id, session.connection_id)
 984        .await?;
 985
 986    let collaborators = project
 987        .collaborators
 988        .iter()
 989        .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32)
 990        .map(|collaborator| proto::Collaborator {
 991            peer_id: collaborator.connection_id as u32,
 992            replica_id: collaborator.replica_id.0 as u32,
 993            user_id: collaborator.user_id.to_proto(),
 994        })
 995        .collect::<Vec<_>>();
 996    let worktrees = project
 997        .worktrees
 998        .iter()
 999        .map(|(id, worktree)| proto::WorktreeMetadata {
1000            id: id.to_proto(),
1001            root_name: worktree.root_name.clone(),
1002            visible: worktree.visible,
1003            abs_path: worktree.abs_path.clone(),
1004        })
1005        .collect::<Vec<_>>();
1006
1007    for collaborator in &collaborators {
1008        session
1009            .peer
1010            .send(
1011                ConnectionId(collaborator.peer_id),
1012                proto::AddProjectCollaborator {
1013                    project_id: project_id.to_proto(),
1014                    collaborator: Some(proto::Collaborator {
1015                        peer_id: session.connection_id.0,
1016                        replica_id: replica_id.0 as u32,
1017                        user_id: guest_user_id.to_proto(),
1018                    }),
1019                },
1020            )
1021            .trace_err();
1022    }
1023
1024    // First, we send the metadata associated with each worktree.
1025    response.send(proto::JoinProjectResponse {
1026        worktrees: worktrees.clone(),
1027        replica_id: replica_id.0 as u32,
1028        collaborators: collaborators.clone(),
1029        language_servers: project.language_servers.clone(),
1030    })?;
1031
1032    for (worktree_id, worktree) in project.worktrees {
1033        #[cfg(any(test, feature = "test-support"))]
1034        const MAX_CHUNK_SIZE: usize = 2;
1035        #[cfg(not(any(test, feature = "test-support")))]
1036        const MAX_CHUNK_SIZE: usize = 256;
1037
1038        // Stream this worktree's entries.
1039        let message = proto::UpdateWorktree {
1040            project_id: project_id.to_proto(),
1041            worktree_id: worktree_id.to_proto(),
1042            abs_path: worktree.abs_path.clone(),
1043            root_name: worktree.root_name,
1044            updated_entries: worktree.entries,
1045            removed_entries: Default::default(),
1046            scan_id: worktree.scan_id,
1047            is_last_update: worktree.is_complete,
1048        };
1049        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1050            session.peer.send(session.connection_id, update.clone())?;
1051        }
1052
1053        // Stream this worktree's diagnostics.
1054        for summary in worktree.diagnostic_summaries {
1055            session.peer.send(
1056                session.connection_id,
1057                proto::UpdateDiagnosticSummary {
1058                    project_id: project_id.to_proto(),
1059                    worktree_id: worktree.id.to_proto(),
1060                    summary: Some(summary),
1061                },
1062            )?;
1063        }
1064    }
1065
1066    for language_server in &project.language_servers {
1067        session.peer.send(
1068            session.connection_id,
1069            proto::UpdateLanguageServer {
1070                project_id: project_id.to_proto(),
1071                language_server_id: language_server.id,
1072                variant: Some(
1073                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1074                        proto::LspDiskBasedDiagnosticsUpdated {},
1075                    ),
1076                ),
1077            },
1078        )?;
1079    }
1080
1081    Ok(())
1082}
1083
1084async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1085    let sender_id = session.connection_id;
1086    let project_id = ProjectId::from_proto(request.project_id);
1087    let project;
1088    {
1089        project = session
1090            .db()
1091            .await
1092            .leave_project(project_id, sender_id)
1093            .await?;
1094        tracing::info!(
1095            %project_id,
1096            host_user_id = %project.host_user_id,
1097            host_connection_id = %project.host_connection_id,
1098            "leave project"
1099        );
1100
1101        broadcast(sender_id, project.connection_ids, |conn_id| {
1102            session.peer.send(
1103                conn_id,
1104                proto::RemoveProjectCollaborator {
1105                    project_id: project_id.to_proto(),
1106                    peer_id: sender_id.0,
1107                },
1108            )
1109        });
1110    }
1111
1112    Ok(())
1113}
1114
1115async fn update_project(
1116    request: proto::UpdateProject,
1117    response: Response<proto::UpdateProject>,
1118    session: Session,
1119) -> Result<()> {
1120    let project_id = ProjectId::from_proto(request.project_id);
1121    let (room, guest_connection_ids) = session
1122        .db()
1123        .await
1124        .update_project(project_id, session.connection_id, &request.worktrees)
1125        .await?;
1126    broadcast(
1127        session.connection_id,
1128        guest_connection_ids,
1129        |connection_id| {
1130            session
1131                .peer
1132                .forward_send(session.connection_id, connection_id, request.clone())
1133        },
1134    );
1135    room_updated(&room, &session);
1136    response.send(proto::Ack {})?;
1137
1138    Ok(())
1139}
1140
1141async fn update_worktree(
1142    request: proto::UpdateWorktree,
1143    response: Response<proto::UpdateWorktree>,
1144    session: Session,
1145) -> Result<()> {
1146    let guest_connection_ids = session
1147        .db()
1148        .await
1149        .update_worktree(&request, session.connection_id)
1150        .await?;
1151
1152    broadcast(
1153        session.connection_id,
1154        guest_connection_ids,
1155        |connection_id| {
1156            session
1157                .peer
1158                .forward_send(session.connection_id, connection_id, request.clone())
1159        },
1160    );
1161    response.send(proto::Ack {})?;
1162    Ok(())
1163}
1164
1165async fn update_diagnostic_summary(
1166    request: proto::UpdateDiagnosticSummary,
1167    response: Response<proto::UpdateDiagnosticSummary>,
1168    session: Session,
1169) -> Result<()> {
1170    let guest_connection_ids = session
1171        .db()
1172        .await
1173        .update_diagnostic_summary(&request, session.connection_id)
1174        .await?;
1175
1176    broadcast(
1177        session.connection_id,
1178        guest_connection_ids,
1179        |connection_id| {
1180            session
1181                .peer
1182                .forward_send(session.connection_id, connection_id, request.clone())
1183        },
1184    );
1185
1186    response.send(proto::Ack {})?;
1187    Ok(())
1188}
1189
1190async fn start_language_server(
1191    request: proto::StartLanguageServer,
1192    session: Session,
1193) -> Result<()> {
1194    let guest_connection_ids = session
1195        .db()
1196        .await
1197        .start_language_server(&request, session.connection_id)
1198        .await?;
1199
1200    broadcast(
1201        session.connection_id,
1202        guest_connection_ids,
1203        |connection_id| {
1204            session
1205                .peer
1206                .forward_send(session.connection_id, connection_id, request.clone())
1207        },
1208    );
1209    Ok(())
1210}
1211
1212async fn update_language_server(
1213    request: proto::UpdateLanguageServer,
1214    session: Session,
1215) -> Result<()> {
1216    let project_id = ProjectId::from_proto(request.project_id);
1217    let project_connection_ids = session
1218        .db()
1219        .await
1220        .project_connection_ids(project_id, session.connection_id)
1221        .await?;
1222    broadcast(
1223        session.connection_id,
1224        project_connection_ids,
1225        |connection_id| {
1226            session
1227                .peer
1228                .forward_send(session.connection_id, connection_id, request.clone())
1229        },
1230    );
1231    Ok(())
1232}
1233
1234async fn forward_project_request<T>(
1235    request: T,
1236    response: Response<T>,
1237    session: Session,
1238) -> Result<()>
1239where
1240    T: EntityMessage + RequestMessage,
1241{
1242    let project_id = ProjectId::from_proto(request.remote_entity_id());
1243    let collaborators = session
1244        .db()
1245        .await
1246        .project_collaborators(project_id, session.connection_id)
1247        .await?;
1248    let host = collaborators
1249        .iter()
1250        .find(|collaborator| collaborator.is_host)
1251        .ok_or_else(|| anyhow!("host not found"))?;
1252
1253    let payload = session
1254        .peer
1255        .forward_request(
1256            session.connection_id,
1257            ConnectionId(host.connection_id as u32),
1258            request,
1259        )
1260        .await?;
1261
1262    response.send(payload)?;
1263    Ok(())
1264}
1265
1266async fn save_buffer(
1267    request: proto::SaveBuffer,
1268    response: Response<proto::SaveBuffer>,
1269    session: Session,
1270) -> Result<()> {
1271    let project_id = ProjectId::from_proto(request.project_id);
1272    let collaborators = session
1273        .db()
1274        .await
1275        .project_collaborators(project_id, session.connection_id)
1276        .await?;
1277    let host = collaborators
1278        .into_iter()
1279        .find(|collaborator| collaborator.is_host)
1280        .ok_or_else(|| anyhow!("host not found"))?;
1281    let host_connection_id = ConnectionId(host.connection_id as u32);
1282    let response_payload = session
1283        .peer
1284        .forward_request(session.connection_id, host_connection_id, request.clone())
1285        .await?;
1286
1287    let mut collaborators = session
1288        .db()
1289        .await
1290        .project_collaborators(project_id, session.connection_id)
1291        .await?;
1292    collaborators
1293        .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32);
1294    let project_connection_ids = collaborators
1295        .into_iter()
1296        .map(|collaborator| ConnectionId(collaborator.connection_id as u32));
1297    broadcast(host_connection_id, project_connection_ids, |conn_id| {
1298        session
1299            .peer
1300            .forward_send(host_connection_id, conn_id, response_payload.clone())
1301    });
1302    response.send(response_payload)?;
1303    Ok(())
1304}
1305
1306async fn create_buffer_for_peer(
1307    request: proto::CreateBufferForPeer,
1308    session: Session,
1309) -> Result<()> {
1310    session.peer.forward_send(
1311        session.connection_id,
1312        ConnectionId(request.peer_id),
1313        request,
1314    )?;
1315    Ok(())
1316}
1317
1318async fn update_buffer(
1319    request: proto::UpdateBuffer,
1320    response: Response<proto::UpdateBuffer>,
1321    session: Session,
1322) -> Result<()> {
1323    let project_id = ProjectId::from_proto(request.project_id);
1324    let project_connection_ids = session
1325        .db()
1326        .await
1327        .project_connection_ids(project_id, session.connection_id)
1328        .await?;
1329
1330    broadcast(
1331        session.connection_id,
1332        project_connection_ids,
1333        |connection_id| {
1334            session
1335                .peer
1336                .forward_send(session.connection_id, connection_id, request.clone())
1337        },
1338    );
1339    response.send(proto::Ack {})?;
1340    Ok(())
1341}
1342
1343async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1344    let project_id = ProjectId::from_proto(request.project_id);
1345    let project_connection_ids = session
1346        .db()
1347        .await
1348        .project_connection_ids(project_id, session.connection_id)
1349        .await?;
1350
1351    broadcast(
1352        session.connection_id,
1353        project_connection_ids,
1354        |connection_id| {
1355            session
1356                .peer
1357                .forward_send(session.connection_id, connection_id, request.clone())
1358        },
1359    );
1360    Ok(())
1361}
1362
1363async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1364    let project_id = ProjectId::from_proto(request.project_id);
1365    let project_connection_ids = session
1366        .db()
1367        .await
1368        .project_connection_ids(project_id, session.connection_id)
1369        .await?;
1370    broadcast(
1371        session.connection_id,
1372        project_connection_ids,
1373        |connection_id| {
1374            session
1375                .peer
1376                .forward_send(session.connection_id, connection_id, request.clone())
1377        },
1378    );
1379    Ok(())
1380}
1381
1382async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1383    let project_id = ProjectId::from_proto(request.project_id);
1384    let project_connection_ids = session
1385        .db()
1386        .await
1387        .project_connection_ids(project_id, session.connection_id)
1388        .await?;
1389    broadcast(
1390        session.connection_id,
1391        project_connection_ids,
1392        |connection_id| {
1393            session
1394                .peer
1395                .forward_send(session.connection_id, connection_id, request.clone())
1396        },
1397    );
1398    Ok(())
1399}
1400
1401async fn follow(
1402    request: proto::Follow,
1403    response: Response<proto::Follow>,
1404    session: Session,
1405) -> Result<()> {
1406    let project_id = ProjectId::from_proto(request.project_id);
1407    let leader_id = ConnectionId(request.leader_id);
1408    let follower_id = session.connection_id;
1409    let project_connection_ids = session
1410        .db()
1411        .await
1412        .project_connection_ids(project_id, session.connection_id)
1413        .await?;
1414
1415    if !project_connection_ids.contains(&leader_id) {
1416        Err(anyhow!("no such peer"))?;
1417    }
1418
1419    let mut response_payload = session
1420        .peer
1421        .forward_request(session.connection_id, leader_id, request)
1422        .await?;
1423    response_payload
1424        .views
1425        .retain(|view| view.leader_id != Some(follower_id.0));
1426    response.send(response_payload)?;
1427    Ok(())
1428}
1429
1430async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1431    let project_id = ProjectId::from_proto(request.project_id);
1432    let leader_id = ConnectionId(request.leader_id);
1433    let project_connection_ids = session
1434        .db()
1435        .await
1436        .project_connection_ids(project_id, session.connection_id)
1437        .await?;
1438    if !project_connection_ids.contains(&leader_id) {
1439        Err(anyhow!("no such peer"))?;
1440    }
1441    session
1442        .peer
1443        .forward_send(session.connection_id, leader_id, request)?;
1444    Ok(())
1445}
1446
1447async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1448    let project_id = ProjectId::from_proto(request.project_id);
1449    let project_connection_ids = session
1450        .db
1451        .lock()
1452        .await
1453        .project_connection_ids(project_id, session.connection_id)
1454        .await?;
1455
1456    let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1457        proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1458        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1459        proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1460    });
1461    for follower_id in &request.follower_ids {
1462        let follower_id = ConnectionId(*follower_id);
1463        if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1464            session
1465                .peer
1466                .forward_send(session.connection_id, follower_id, request.clone())?;
1467        }
1468    }
1469    Ok(())
1470}
1471
1472async fn get_users(
1473    request: proto::GetUsers,
1474    response: Response<proto::GetUsers>,
1475    session: Session,
1476) -> Result<()> {
1477    let user_ids = request
1478        .user_ids
1479        .into_iter()
1480        .map(UserId::from_proto)
1481        .collect();
1482    let users = session
1483        .db()
1484        .await
1485        .get_users_by_ids(user_ids)
1486        .await?
1487        .into_iter()
1488        .map(|user| proto::User {
1489            id: user.id.to_proto(),
1490            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1491            github_login: user.github_login,
1492        })
1493        .collect();
1494    response.send(proto::UsersResponse { users })?;
1495    Ok(())
1496}
1497
1498async fn fuzzy_search_users(
1499    request: proto::FuzzySearchUsers,
1500    response: Response<proto::FuzzySearchUsers>,
1501    session: Session,
1502) -> Result<()> {
1503    let query = request.query;
1504    let users = match query.len() {
1505        0 => vec![],
1506        1 | 2 => session
1507            .db()
1508            .await
1509            .get_user_by_github_account(&query, None)
1510            .await?
1511            .into_iter()
1512            .collect(),
1513        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1514    };
1515    let users = users
1516        .into_iter()
1517        .filter(|user| user.id != session.user_id)
1518        .map(|user| proto::User {
1519            id: user.id.to_proto(),
1520            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1521            github_login: user.github_login,
1522        })
1523        .collect();
1524    response.send(proto::UsersResponse { users })?;
1525    Ok(())
1526}
1527
1528async fn request_contact(
1529    request: proto::RequestContact,
1530    response: Response<proto::RequestContact>,
1531    session: Session,
1532) -> Result<()> {
1533    let requester_id = session.user_id;
1534    let responder_id = UserId::from_proto(request.responder_id);
1535    if requester_id == responder_id {
1536        return Err(anyhow!("cannot add yourself as a contact"))?;
1537    }
1538
1539    session
1540        .db()
1541        .await
1542        .send_contact_request(requester_id, responder_id)
1543        .await?;
1544
1545    // Update outgoing contact requests of requester
1546    let mut update = proto::UpdateContacts::default();
1547    update.outgoing_requests.push(responder_id.to_proto());
1548    for connection_id in session
1549        .connection_pool()
1550        .await
1551        .user_connection_ids(requester_id)
1552    {
1553        session.peer.send(connection_id, update.clone())?;
1554    }
1555
1556    // Update incoming contact requests of responder
1557    let mut update = proto::UpdateContacts::default();
1558    update
1559        .incoming_requests
1560        .push(proto::IncomingContactRequest {
1561            requester_id: requester_id.to_proto(),
1562            should_notify: true,
1563        });
1564    for connection_id in session
1565        .connection_pool()
1566        .await
1567        .user_connection_ids(responder_id)
1568    {
1569        session.peer.send(connection_id, update.clone())?;
1570    }
1571
1572    response.send(proto::Ack {})?;
1573    Ok(())
1574}
1575
1576async fn respond_to_contact_request(
1577    request: proto::RespondToContactRequest,
1578    response: Response<proto::RespondToContactRequest>,
1579    session: Session,
1580) -> Result<()> {
1581    let responder_id = session.user_id;
1582    let requester_id = UserId::from_proto(request.requester_id);
1583    let db = session.db().await;
1584    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1585        db.dismiss_contact_notification(responder_id, requester_id)
1586            .await?;
1587    } else {
1588        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1589
1590        db.respond_to_contact_request(responder_id, requester_id, accept)
1591            .await?;
1592        let busy = db.is_user_busy(requester_id).await?;
1593
1594        let pool = session.connection_pool().await;
1595        // Update responder with new contact
1596        let mut update = proto::UpdateContacts::default();
1597        if accept {
1598            update
1599                .contacts
1600                .push(contact_for_user(requester_id, false, busy, &pool));
1601        }
1602        update
1603            .remove_incoming_requests
1604            .push(requester_id.to_proto());
1605        for connection_id in pool.user_connection_ids(responder_id) {
1606            session.peer.send(connection_id, update.clone())?;
1607        }
1608
1609        // Update requester with new contact
1610        let mut update = proto::UpdateContacts::default();
1611        if accept {
1612            update
1613                .contacts
1614                .push(contact_for_user(responder_id, true, busy, &pool));
1615        }
1616        update
1617            .remove_outgoing_requests
1618            .push(responder_id.to_proto());
1619        for connection_id in pool.user_connection_ids(requester_id) {
1620            session.peer.send(connection_id, update.clone())?;
1621        }
1622    }
1623
1624    response.send(proto::Ack {})?;
1625    Ok(())
1626}
1627
1628async fn remove_contact(
1629    request: proto::RemoveContact,
1630    response: Response<proto::RemoveContact>,
1631    session: Session,
1632) -> Result<()> {
1633    let requester_id = session.user_id;
1634    let responder_id = UserId::from_proto(request.user_id);
1635    let db = session.db().await;
1636    db.remove_contact(requester_id, responder_id).await?;
1637
1638    let pool = session.connection_pool().await;
1639    // Update outgoing contact requests of requester
1640    let mut update = proto::UpdateContacts::default();
1641    update
1642        .remove_outgoing_requests
1643        .push(responder_id.to_proto());
1644    for connection_id in pool.user_connection_ids(requester_id) {
1645        session.peer.send(connection_id, update.clone())?;
1646    }
1647
1648    // Update incoming contact requests of responder
1649    let mut update = proto::UpdateContacts::default();
1650    update
1651        .remove_incoming_requests
1652        .push(requester_id.to_proto());
1653    for connection_id in pool.user_connection_ids(responder_id) {
1654        session.peer.send(connection_id, update.clone())?;
1655    }
1656
1657    response.send(proto::Ack {})?;
1658    Ok(())
1659}
1660
1661async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
1662    let project_id = ProjectId::from_proto(request.project_id);
1663    let project_connection_ids = session
1664        .db()
1665        .await
1666        .project_connection_ids(project_id, session.connection_id)
1667        .await?;
1668    broadcast(
1669        session.connection_id,
1670        project_connection_ids,
1671        |connection_id| {
1672            session
1673                .peer
1674                .forward_send(session.connection_id, connection_id, request.clone())
1675        },
1676    );
1677    Ok(())
1678}
1679
1680async fn get_private_user_info(
1681    _request: proto::GetPrivateUserInfo,
1682    response: Response<proto::GetPrivateUserInfo>,
1683    session: Session,
1684) -> Result<()> {
1685    let metrics_id = session
1686        .db()
1687        .await
1688        .get_user_metrics_id(session.user_id)
1689        .await?;
1690    let user = session
1691        .db()
1692        .await
1693        .get_user_by_id(session.user_id)
1694        .await?
1695        .ok_or_else(|| anyhow!("user not found"))?;
1696    response.send(proto::GetPrivateUserInfoResponse {
1697        metrics_id,
1698        staff: user.admin,
1699    })?;
1700    Ok(())
1701}
1702
1703fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1704    match message {
1705        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1706        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1707        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1708        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1709        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1710            code: frame.code.into(),
1711            reason: frame.reason,
1712        })),
1713    }
1714}
1715
1716fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1717    match message {
1718        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1719        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1720        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1721        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1722        AxumMessage::Close(frame) => {
1723            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1724                code: frame.code.into(),
1725                reason: frame.reason,
1726            }))
1727        }
1728    }
1729}
1730
1731fn build_initial_contacts_update(
1732    contacts: Vec<db::Contact>,
1733    pool: &ConnectionPool,
1734) -> proto::UpdateContacts {
1735    let mut update = proto::UpdateContacts::default();
1736
1737    for contact in contacts {
1738        match contact {
1739            db::Contact::Accepted {
1740                user_id,
1741                should_notify,
1742                busy,
1743            } => {
1744                update
1745                    .contacts
1746                    .push(contact_for_user(user_id, should_notify, busy, &pool));
1747            }
1748            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
1749            db::Contact::Incoming {
1750                user_id,
1751                should_notify,
1752            } => update
1753                .incoming_requests
1754                .push(proto::IncomingContactRequest {
1755                    requester_id: user_id.to_proto(),
1756                    should_notify,
1757                }),
1758        }
1759    }
1760
1761    update
1762}
1763
1764fn contact_for_user(
1765    user_id: UserId,
1766    should_notify: bool,
1767    busy: bool,
1768    pool: &ConnectionPool,
1769) -> proto::Contact {
1770    proto::Contact {
1771        user_id: user_id.to_proto(),
1772        online: pool.is_user_online(user_id),
1773        busy,
1774        should_notify,
1775    }
1776}
1777
1778fn room_updated(room: &proto::Room, session: &Session) {
1779    for participant in &room.participants {
1780        session
1781            .peer
1782            .send(
1783                ConnectionId(participant.peer_id),
1784                proto::RoomUpdated {
1785                    room: Some(room.clone()),
1786                },
1787            )
1788            .trace_err();
1789    }
1790}
1791
1792async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
1793    let db = session.db().await;
1794    let contacts = db.get_contacts(user_id).await?;
1795    let busy = db.is_user_busy(user_id).await?;
1796
1797    let pool = session.connection_pool().await;
1798    let updated_contact = contact_for_user(user_id, false, busy, &pool);
1799    for contact in contacts {
1800        if let db::Contact::Accepted {
1801            user_id: contact_user_id,
1802            ..
1803        } = contact
1804        {
1805            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
1806                session
1807                    .peer
1808                    .send(
1809                        contact_conn_id,
1810                        proto::UpdateContacts {
1811                            contacts: vec![updated_contact.clone()],
1812                            remove_contacts: Default::default(),
1813                            incoming_requests: Default::default(),
1814                            remove_incoming_requests: Default::default(),
1815                            outgoing_requests: Default::default(),
1816                            remove_outgoing_requests: Default::default(),
1817                        },
1818                    )
1819                    .trace_err();
1820            }
1821        }
1822    }
1823    Ok(())
1824}
1825
1826async fn leave_room_for_session(session: &Session) -> Result<()> {
1827    let mut contacts_to_update = HashSet::default();
1828
1829    let Some(left_room) = session.db().await.leave_room(session.connection_id).await? else {
1830        return Err(anyhow!("no room to leave"))?;
1831    };
1832    contacts_to_update.insert(session.user_id);
1833
1834    for project in left_room.left_projects.into_values() {
1835        for connection_id in project.connection_ids {
1836            if project.host_user_id == session.user_id {
1837                session
1838                    .peer
1839                    .send(
1840                        connection_id,
1841                        proto::UnshareProject {
1842                            project_id: project.id.to_proto(),
1843                        },
1844                    )
1845                    .trace_err();
1846            } else {
1847                session
1848                    .peer
1849                    .send(
1850                        connection_id,
1851                        proto::RemoveProjectCollaborator {
1852                            project_id: project.id.to_proto(),
1853                            peer_id: session.connection_id.0,
1854                        },
1855                    )
1856                    .trace_err();
1857            }
1858        }
1859
1860        session
1861            .peer
1862            .send(
1863                session.connection_id,
1864                proto::UnshareProject {
1865                    project_id: project.id.to_proto(),
1866                },
1867            )
1868            .trace_err();
1869    }
1870
1871    room_updated(&left_room.room, &session);
1872    {
1873        let pool = session.connection_pool().await;
1874        for canceled_user_id in left_room.canceled_calls_to_user_ids {
1875            for connection_id in pool.user_connection_ids(canceled_user_id) {
1876                session
1877                    .peer
1878                    .send(connection_id, proto::CallCanceled {})
1879                    .trace_err();
1880            }
1881            contacts_to_update.insert(canceled_user_id);
1882        }
1883    }
1884
1885    for contact_user_id in contacts_to_update {
1886        update_user_contacts(contact_user_id, &session).await?;
1887    }
1888
1889    if let Some(live_kit) = session.live_kit_client.as_ref() {
1890        live_kit
1891            .remove_participant(
1892                left_room.room.live_kit_room.clone(),
1893                session.connection_id.to_string(),
1894            )
1895            .await
1896            .trace_err();
1897
1898        if left_room.room.participants.is_empty() {
1899            live_kit
1900                .delete_room(left_room.room.live_kit_room)
1901                .await
1902                .trace_err();
1903        }
1904    }
1905
1906    Ok(())
1907}
1908
1909pub trait ResultExt {
1910    type Ok;
1911
1912    fn trace_err(self) -> Option<Self::Ok>;
1913}
1914
1915impl<T, E> ResultExt for Result<T, E>
1916where
1917    E: std::fmt::Debug,
1918{
1919    type Ok = T;
1920
1921    fn trace_err(self) -> Option<T> {
1922        match self {
1923            Ok(value) => Some(value),
1924            Err(error) => {
1925                tracing::error!("{:?}", error);
1926                None
1927            }
1928        }
1929    }
1930}