rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{self, Database, ProjectId, RoomId, User, UserId},
   6    AppState, Result,
   7};
   8use anyhow::anyhow;
   9use async_tungstenite::tungstenite::{
  10    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  11};
  12use axum::{
  13    body::Body,
  14    extract::{
  15        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  16        ConnectInfo, WebSocketUpgrade,
  17    },
  18    headers::{Header, HeaderName},
  19    http::StatusCode,
  20    middleware,
  21    response::IntoResponse,
  22    routing::get,
  23    Extension, Router, TypedHeader,
  24};
  25use collections::{HashMap, HashSet};
  26pub use connection_pool::ConnectionPool;
  27use futures::{
  28    channel::oneshot,
  29    future::{self, BoxFuture},
  30    stream::FuturesUnordered,
  31    FutureExt, SinkExt, StreamExt, TryStreamExt,
  32};
  33use lazy_static::lazy_static;
  34use prometheus::{register_int_gauge, IntGauge};
  35use rpc::{
  36    proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
  37    Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
  38};
  39use serde::{Serialize, Serializer};
  40use std::{
  41    any::TypeId,
  42    fmt,
  43    future::Future,
  44    marker::PhantomData,
  45    mem,
  46    net::SocketAddr,
  47    ops::{Deref, DerefMut},
  48    rc::Rc,
  49    sync::{
  50        atomic::{AtomicBool, Ordering::SeqCst},
  51        Arc,
  52    },
  53    time::Duration,
  54};
  55use tokio::{
  56    sync::{Mutex, MutexGuard},
  57    time::Sleep,
  58};
  59use tower::ServiceBuilder;
  60use tracing::{info_span, instrument, Instrument};
  61
lazy_static! {
    // Prometheus gauge registered under the name "connections". It is set in
    // `handle_metrics` to the number of non-admin connections in the pool.
    // Registration failure panics on first access (`unwrap`).
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Prometheus gauge registered under the name "shared_projects". It is set
    // in `handle_metrics` from `project_count_excluding_admins`.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
  71
/// Type-erased message handler as stored in `Server::handlers`.
///
/// It receives a boxed envelope of any message type together with the
/// `Session` the message arrived on, and returns a boxed future resolving to
/// `()`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  74
/// Handle passed to request handlers for replying to a request of type `R`.
struct Response<R> {
    // Peer through which the reply is delivered.
    peer: Arc<Peer>,
    // Receipt of the request being answered.
    receipt: Receipt<R>,
    // Set to `true` by `send`; `Server::add_request_handler` reads it to detect
    // handlers that returned `Ok` without replying.
    responded: Arc<AtomicBool>,
}
  80
  81impl<R: RequestMessage> Response<R> {
  82    fn send(self, payload: R::Response) -> Result<()> {
  83        self.responded.store(true, SeqCst);
  84        self.peer.respond(self.receipt, payload)?;
  85        Ok(())
  86    }
  87}
  88
/// Per-connection state handed (by clone) to every message handler.
#[derive(Clone)]
struct Session {
    // User that owns this connection.
    user_id: UserId,
    // Identifier of this connection within the peer.
    connection_id: ConnectionId,
    // Database handle behind an async mutex; access it through `Session::db`.
    db: Arc<Mutex<DbHandle>>,
    // RPC peer used to send messages and responses.
    peer: Arc<Peer>,
    // Connection pool shared with the `Server`; access it through
    // `Session::connection_pool`.
    connection_pool: Arc<Mutex<ConnectionPool>>,
    // LiveKit API client; `None` means the handlers skip LiveKit setup.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
}
  98
impl Session {
    /// Locks this session's database handle and returns the guard.
    ///
    /// In test builds the task yields to the scheduler once before and once
    /// after acquiring the lock.
    async fn db(&self) -> MutexGuard<DbHandle> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.db.lock().await;
        #[cfg(test)]
        tokio::task::yield_now().await;
        guard
    }

    /// Locks the shared connection pool and returns it wrapped in a
    /// `ConnectionPoolGuard`.
    ///
    /// In test builds the task yields to the scheduler once before and once
    /// after acquiring the lock.
    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.connection_pool.lock().await;
        #[cfg(test)]
        tokio::task::yield_now().await;
        ConnectionPoolGuard {
            guard,
            _not_send: PhantomData,
        }
    }
}
 121
 122impl fmt::Debug for Session {
 123    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 124        f.debug_struct("Session")
 125            .field("user_id", &self.user_id)
 126            .field("connection_id", &self.connection_id)
 127            .finish()
 128    }
 129}
 130
/// Newtype around a shared `Database`; dereferences to the `Database` itself.
struct DbHandle(Arc<Database>);
 132
 133impl Deref for DbHandle {
 134    type Target = Database;
 135
 136    fn deref(&self) -> &Self::Target {
 137        self.0.as_ref()
 138    }
 139}
 140
/// RPC server state: the peer, the connection pool, application state, and
/// the table of registered message handlers.
pub struct Server {
    // RPC peer that owns the underlying connections.
    peer: Arc<Peer>,
    // Pool of live connections, guarded by an async mutex.
    pub(crate) connection_pool: Arc<Mutex<ConnectionPool>>,
    // Shared application state (database, config, LiveKit client).
    app_state: Arc<AppState>,
    // Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
}
 147
/// Abstraction over the async runtime used by `Server::handle_connection`.
pub trait Executor: Send + Clone {
    /// Future returned by [`Executor::sleep`].
    type Sleep: Send + Future;
    /// Runs `future` without handing back any handle to it.
    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
    /// Returns a future associated with the given `duration`.
    fn sleep(&self, duration: Duration) -> Self::Sleep;
}
 153
/// `Executor` implementation backed by tokio (see its `Executor` impl).
#[derive(Clone)]
pub struct RealExecutor;
 156
/// Wrapper around a locked `ConnectionPool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    // The held lock on the pool.
    guard: MutexGuard<'a, ConnectionPool>,
    // `Rc` is not `Send`, so this marker makes the whole guard `!Send`.
    _not_send: PhantomData<Rc<()>>,
}
 161
/// Serializable view of the server, produced by `Server::snapshot`.
///
/// Holds the connection-pool lock for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    // Borrowed RPC peer.
    peer: &'a Peer,
    // Locked pool; serialized through `serialize_deref`, i.e. as the
    // `ConnectionPool` the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 168
 169pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 170where
 171    S: Serializer,
 172    T: Deref<Target = U>,
 173    U: Serialize,
 174{
 175    Serialize::serialize(value.deref(), serializer)
 176}
 177
 178impl Server {
 179    pub fn new(app_state: Arc<AppState>) -> Arc<Self> {
 180        let mut server = Self {
 181            peer: Peer::new(),
 182            app_state,
 183            connection_pool: Default::default(),
 184            handlers: Default::default(),
 185        };
 186
 187        server
 188            .add_request_handler(ping)
 189            .add_request_handler(create_room)
 190            .add_request_handler(join_room)
 191            .add_message_handler(leave_room)
 192            .add_request_handler(call)
 193            .add_request_handler(cancel_call)
 194            .add_message_handler(decline_call)
 195            .add_request_handler(update_participant_location)
 196            .add_request_handler(share_project)
 197            .add_message_handler(unshare_project)
 198            .add_request_handler(join_project)
 199            .add_message_handler(leave_project)
 200            .add_request_handler(update_project)
 201            .add_request_handler(update_worktree)
 202            .add_message_handler(start_language_server)
 203            .add_message_handler(update_language_server)
 204            .add_message_handler(update_diagnostic_summary)
 205            .add_request_handler(forward_project_request::<proto::GetHover>)
 206            .add_request_handler(forward_project_request::<proto::GetDefinition>)
 207            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
 208            .add_request_handler(forward_project_request::<proto::GetReferences>)
 209            .add_request_handler(forward_project_request::<proto::SearchProject>)
 210            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
 211            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
 212            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
 213            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
 214            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
 215            .add_request_handler(forward_project_request::<proto::GetCompletions>)
 216            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
 217            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
 218            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
 219            .add_request_handler(forward_project_request::<proto::PrepareRename>)
 220            .add_request_handler(forward_project_request::<proto::PerformRename>)
 221            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
 222            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
 223            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
 224            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
 225            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
 226            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
 227            .add_message_handler(create_buffer_for_peer)
 228            .add_request_handler(update_buffer)
 229            .add_message_handler(update_buffer_file)
 230            .add_message_handler(buffer_reloaded)
 231            .add_message_handler(buffer_saved)
 232            .add_request_handler(save_buffer)
 233            .add_request_handler(get_users)
 234            .add_request_handler(fuzzy_search_users)
 235            .add_request_handler(request_contact)
 236            .add_request_handler(remove_contact)
 237            .add_request_handler(respond_to_contact_request)
 238            .add_request_handler(follow)
 239            .add_message_handler(unfollow)
 240            .add_message_handler(update_followers)
 241            .add_message_handler(update_diff_base)
 242            .add_request_handler(get_private_user_info);
 243
 244        Arc::new(server)
 245    }
 246
 247    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 248    where
 249        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 250        Fut: 'static + Send + Future<Output = Result<()>>,
 251        M: EnvelopedMessage,
 252    {
 253        let prev_handler = self.handlers.insert(
 254            TypeId::of::<M>(),
 255            Box::new(move |envelope, session| {
 256                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 257                let span = info_span!(
 258                    "handle message",
 259                    payload_type = envelope.payload_type_name()
 260                );
 261                span.in_scope(|| {
 262                    tracing::info!(
 263                        payload_type = envelope.payload_type_name(),
 264                        "message received"
 265                    );
 266                });
 267                let future = (handler)(*envelope, session);
 268                async move {
 269                    if let Err(error) = future.await {
 270                        tracing::error!(%error, "error handling message");
 271                    }
 272                }
 273                .instrument(span)
 274                .boxed()
 275            }),
 276        );
 277        if prev_handler.is_some() {
 278            panic!("registered a handler for the same message twice");
 279        }
 280        self
 281    }
 282
 283    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 284    where
 285        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 286        Fut: 'static + Send + Future<Output = Result<()>>,
 287        M: EnvelopedMessage,
 288    {
 289        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 290        self
 291    }
 292
 293    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 294    where
 295        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 296        Fut: Send + Future<Output = Result<()>>,
 297        M: RequestMessage,
 298    {
 299        let handler = Arc::new(handler);
 300        self.add_handler(move |envelope, session| {
 301            let receipt = envelope.receipt();
 302            let handler = handler.clone();
 303            async move {
 304                let peer = session.peer.clone();
 305                let responded = Arc::new(AtomicBool::default());
 306                let response = Response {
 307                    peer: peer.clone(),
 308                    responded: responded.clone(),
 309                    receipt,
 310                };
 311                match (handler)(envelope.payload, response, session).await {
 312                    Ok(()) => {
 313                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 314                            Ok(())
 315                        } else {
 316                            Err(anyhow!("handler did not send a response"))?
 317                        }
 318                    }
 319                    Err(error) => {
 320                        peer.respond_with_error(
 321                            receipt,
 322                            proto::Error {
 323                                message: error.to_string(),
 324                            },
 325                        )?;
 326                        Err(error)
 327                    }
 328                }
 329            }
 330        })
 331    }
 332
    /// Returns a future that drives a single client connection from handshake
    /// to sign-out.
    ///
    /// The future:
    /// 1. adds `connection` to the peer and sends `proto::Hello`;
    /// 2. reports the new connection id through `send_connection_id`, if given;
    /// 3. sends `proto::ShowContacts` and records the first connection for
    ///    users whose `connected_once` flag is unset;
    /// 4. adds the connection to the pool and sends the initial contacts
    ///    update, invite info (if any), and any pending incoming call;
    /// 5. loops, dispatching incoming messages to the registered handlers
    ///    until the I/O future completes or the incoming stream ends;
    /// 6. calls `sign_out` for the session.
    ///
    /// The whole future is instrumented with a "handle connection" span.
    ///
    /// NOTE(review): any `?` that fires after `peer.add_connection` /
    /// `pool.add_connection` but before the message loop returns early without
    /// reaching `sign_out` below — confirm whether the connection is cleaned up
    /// elsewhere in that case.
    pub fn handle_connection<E: Executor>(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: E,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address);
        async move {
            // The closure passed to the peer builds timer futures from the
            // executor's `sleep`.
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| {
                        let timer = executor.sleep(duration);
                        async move {
                            timer.await;
                        }
                    }
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: connection_id.0 })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            // A failed send (receiver dropped) is deliberately ignored.
            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            // Both database queries are awaited concurrently; either error aborts.
            let (contacts, invite_code) = future::try_join(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_invite_code_for_user(user_id)
            ).await?;

            // Scope the pool lock so it is released before the next await on the database.
            {
                let mut pool = this.connection_pool.lock().await;
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;

                if let Some((code, count)) = invite_code {
                    this.peer.send(connection_id, proto::UpdateInviteInfo {
                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
                        count: count as u32,
                    })?;
                }
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone()
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            loop {
                let next_message = incoming_rx.next().fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls the arms in the order written: I/O first,
                // then in-flight foreground handlers, then the next incoming message.
                futures::select_biased! {
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    message = next_message => {
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span before the future is instrumented with it.
                                drop(span_enter);

                                let handle_message = handle_message.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set discards any foreground handlers still in flight.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            // Sign-out failures are logged but do not fail the connection future.
            if let Err(error) = sign_out(session).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
 462
 463    pub async fn invite_code_redeemed(
 464        self: &Arc<Self>,
 465        inviter_id: UserId,
 466        invitee_id: UserId,
 467    ) -> Result<()> {
 468        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 469            if let Some(code) = &user.invite_code {
 470                let pool = self.connection_pool.lock().await;
 471                let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
 472                for connection_id in pool.user_connection_ids(inviter_id) {
 473                    self.peer.send(
 474                        connection_id,
 475                        proto::UpdateContacts {
 476                            contacts: vec![invitee_contact.clone()],
 477                            ..Default::default()
 478                        },
 479                    )?;
 480                    self.peer.send(
 481                        connection_id,
 482                        proto::UpdateInviteInfo {
 483                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 484                            count: user.invite_count as u32,
 485                        },
 486                    )?;
 487                }
 488            }
 489        }
 490        Ok(())
 491    }
 492
 493    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 494        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 495            if let Some(invite_code) = &user.invite_code {
 496                let pool = self.connection_pool.lock().await;
 497                for connection_id in pool.user_connection_ids(user_id) {
 498                    self.peer.send(
 499                        connection_id,
 500                        proto::UpdateInviteInfo {
 501                            url: format!(
 502                                "{}{}",
 503                                self.app_state.config.invite_link_prefix, invite_code
 504                            ),
 505                            count: user.invite_count as u32,
 506                        },
 507                    )?;
 508                }
 509            }
 510        }
 511        Ok(())
 512    }
 513
 514    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 515        ServerSnapshot {
 516            connection_pool: ConnectionPoolGuard {
 517                guard: self.connection_pool.lock().await,
 518                _not_send: PhantomData,
 519            },
 520            peer: &self.peer,
 521        }
 522    }
 523}
 524
 525impl<'a> Deref for ConnectionPoolGuard<'a> {
 526    type Target = ConnectionPool;
 527
 528    fn deref(&self) -> &Self::Target {
 529        &*self.guard
 530    }
 531}
 532
 533impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 534    fn deref_mut(&mut self) -> &mut Self::Target {
 535        &mut *self.guard
 536    }
 537}
 538
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, calls `check_invariants` every time a guard is
    /// released; in other builds the body is empty.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
 545
 546impl Executor for RealExecutor {
 547    type Sleep = Sleep;
 548
 549    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
 550        tokio::task::spawn(future);
 551    }
 552
 553    fn sleep(&self, duration: Duration) -> Self::Sleep {
 554        tokio::time::sleep(duration)
 555    }
 556}
 557
 558fn broadcast<F>(
 559    sender_id: ConnectionId,
 560    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 561    mut f: F,
 562) where
 563    F: FnMut(ConnectionId) -> anyhow::Result<()>,
 564{
 565    for receiver_id in receiver_ids {
 566        if receiver_id != sender_id {
 567            f(receiver_id).trace_err();
 568        }
 569    }
 570}
 571
lazy_static! {
    // Name of the HTTP header that `ProtocolVersion` is decoded from.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
 575
/// Typed value of the `x-zed-protocol-version` header.
pub struct ProtocolVersion(u32);
 577
 578impl Header for ProtocolVersion {
 579    fn name() -> &'static HeaderName {
 580        &ZED_PROTOCOL_VERSION
 581    }
 582
 583    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
 584    where
 585        Self: Sized,
 586        I: Iterator<Item = &'i axum::http::HeaderValue>,
 587    {
 588        let version = values
 589            .next()
 590            .ok_or_else(axum::headers::Error::invalid)?
 591            .to_str()
 592            .map_err(|_| axum::headers::Error::invalid())?
 593            .parse()
 594            .map_err(|_| axum::headers::Error::invalid())?;
 595        Ok(Self(version))
 596    }
 597
 598    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
 599        values.extend([self.0.to_string().parse().unwrap()]);
 600    }
 601}
 602
 603pub fn routes(server: Arc<Server>) -> Router<Body> {
 604    Router::new()
 605        .route("/rpc", get(handle_websocket_request))
 606        .layer(
 607            ServiceBuilder::new()
 608                .layer(Extension(server.app_state.clone()))
 609                .layer(middleware::from_fn(auth::validate_header)),
 610        )
 611        .route("/metrics", get(handle_metrics))
 612        .layer(Extension(server))
 613}
 614
 615pub async fn handle_websocket_request(
 616    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
 617    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
 618    Extension(server): Extension<Arc<Server>>,
 619    Extension(user): Extension<User>,
 620    ws: WebSocketUpgrade,
 621) -> axum::response::Response {
 622    if protocol_version != rpc::PROTOCOL_VERSION {
 623        return (
 624            StatusCode::UPGRADE_REQUIRED,
 625            "client must be upgraded".to_string(),
 626        )
 627            .into_response();
 628    }
 629    let socket_address = socket_address.to_string();
 630    ws.on_upgrade(move |socket| {
 631        use util::ResultExt;
 632        let socket = socket
 633            .map_ok(to_tungstenite_message)
 634            .err_into()
 635            .with(|message| async move { Ok(to_axum_message(message)) });
 636        let connection = Connection::new(Box::pin(socket));
 637        async move {
 638            server
 639                .handle_connection(connection, socket_address, user, None, RealExecutor)
 640                .await
 641                .log_err();
 642        }
 643    })
 644}
 645
 646pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
 647    let connections = server
 648        .connection_pool
 649        .lock()
 650        .await
 651        .connections()
 652        .filter(|connection| !connection.admin)
 653        .count();
 654
 655    METRIC_CONNECTIONS.set(connections as _);
 656
 657    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
 658    METRIC_SHARED_PROJECTS.set(shared_projects as _);
 659
 660    let encoder = prometheus::TextEncoder::new();
 661    let metric_families = prometheus::gather();
 662    let encoded_metrics = encoder
 663        .encode_to_string(&metric_families)
 664        .map_err(|err| anyhow!("{}", err))?;
 665    Ok(encoded_metrics)
 666}
 667
/// Tears down a session: disconnects it from the peer, removes it from the
/// connection pool, leaves its room, declines calls if this was the user's
/// last connection, and refreshes the user's contacts.
///
/// Errors from leaving the room or declining the call are passed to
/// `trace_err` and do not abort the sign-out; errors from removing the
/// connection or updating contacts are returned.
#[instrument(err)]
async fn sign_out(session: Session) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    // Scope the pool lock; `decline_calls` is true when the user has no
    // connections left after this one is removed.
    let decline_calls = {
        let mut pool = session.connection_pool().await;
        pool.remove_connection(session.connection_id)?;
        let mut connections = pool.user_connection_ids(session.user_id);
        connections.next().is_none()
    };

    leave_room_for_session(&session).await.trace_err();
    if decline_calls {
        if let Some(room) = session
            .db()
            .await
            .decline_call(None, session.user_id)
            .await
            .trace_err()
        {
            room_updated(&room, &session);
        }
    }

    update_user_contacts(session.user_id, &session).await?;

    Ok(())
}
 695
 696async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
 697    response.send(proto::Ack {})?;
 698    Ok(())
 699}
 700
 701async fn create_room(
 702    _request: proto::CreateRoom,
 703    response: Response<proto::CreateRoom>,
 704    session: Session,
 705) -> Result<()> {
 706    let live_kit_room = nanoid::nanoid!(30);
 707    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 708        if let Some(_) = live_kit
 709            .create_room(live_kit_room.clone())
 710            .await
 711            .trace_err()
 712        {
 713            if let Some(token) = live_kit
 714                .room_token(&live_kit_room, &session.connection_id.to_string())
 715                .trace_err()
 716            {
 717                Some(proto::LiveKitConnectionInfo {
 718                    server_url: live_kit.url().into(),
 719                    token,
 720                })
 721            } else {
 722                None
 723            }
 724        } else {
 725            None
 726        }
 727    } else {
 728        None
 729    };
 730
 731    {
 732        let room = session
 733            .db()
 734            .await
 735            .create_room(session.user_id, session.connection_id, &live_kit_room)
 736            .await?;
 737
 738        response.send(proto::CreateRoomResponse {
 739            room: Some(room.clone()),
 740            live_kit_connection_info,
 741        })?;
 742    }
 743
 744    update_user_contacts(session.user_id, &session).await?;
 745    Ok(())
 746}
 747
 748async fn join_room(
 749    request: proto::JoinRoom,
 750    response: Response<proto::JoinRoom>,
 751    session: Session,
 752) -> Result<()> {
 753    let room = {
 754        let room = session
 755            .db()
 756            .await
 757            .join_room(
 758                RoomId::from_proto(request.id),
 759                session.user_id,
 760                session.connection_id,
 761            )
 762            .await?;
 763        room_updated(&room, &session);
 764        room.clone()
 765    };
 766
 767    for connection_id in session
 768        .connection_pool()
 769        .await
 770        .user_connection_ids(session.user_id)
 771    {
 772        session
 773            .peer
 774            .send(connection_id, proto::CallCanceled {})
 775            .trace_err();
 776    }
 777
 778    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
 779        if let Some(token) = live_kit
 780            .room_token(&room.live_kit_room, &session.connection_id.to_string())
 781            .trace_err()
 782        {
 783            Some(proto::LiveKitConnectionInfo {
 784                server_url: live_kit.url().into(),
 785                token,
 786            })
 787        } else {
 788            None
 789        }
 790    } else {
 791        None
 792    };
 793
 794    response.send(proto::JoinRoomResponse {
 795        room: Some(room),
 796        live_kit_connection_info,
 797    })?;
 798
 799    update_user_contacts(session.user_id, &session).await?;
 800    Ok(())
 801}
 802
/// Handles `LeaveRoom` by delegating to `leave_room_for_session`; the message
/// payload itself is unused.
async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> {
    leave_room_for_session(&session).await
}
 806
/// Handles `Call`: rings another user into a room.
///
/// Fails with "cannot call a user who isn't a contact" unless the caller and
/// callee are contacts. Otherwise the call is recorded, the room update is
/// broadcast, and the incoming-call message is sent as a request to every
/// connection of the called user. The first connection to answer successfully
/// makes this handler reply with `Ack`. If none does, the call is marked as
/// failed, the room and contacts are updated again, and the handler returns
/// "failed to ring user".
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id;
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Scope the borrowed database result; only the incoming-call message
    // (moved out with `mem::take`) leaves the block.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // One request per connection of the called user, awaited in completion order.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            // First successful answer wins; remaining requests are dropped on return.
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection answered successfully (or the user had none).
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
 874
 875async fn cancel_call(
 876    request: proto::CancelCall,
 877    response: Response<proto::CancelCall>,
 878    session: Session,
 879) -> Result<()> {
 880    let called_user_id = UserId::from_proto(request.called_user_id);
 881    let room_id = RoomId::from_proto(request.room_id);
 882    {
 883        let room = session
 884            .db()
 885            .await
 886            .cancel_call(Some(room_id), session.connection_id, called_user_id)
 887            .await?;
 888        room_updated(&room, &session);
 889    }
 890
 891    for connection_id in session
 892        .connection_pool()
 893        .await
 894        .user_connection_ids(called_user_id)
 895    {
 896        session
 897            .peer
 898            .send(connection_id, proto::CallCanceled {})
 899            .trace_err();
 900    }
 901    response.send(proto::Ack {})?;
 902
 903    update_user_contacts(called_user_id, &session).await?;
 904    Ok(())
 905}
 906
 907async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
 908    let room_id = RoomId::from_proto(message.room_id);
 909    {
 910        let room = session
 911            .db()
 912            .await
 913            .decline_call(Some(room_id), session.user_id)
 914            .await?;
 915        room_updated(&room, &session);
 916    }
 917
 918    for connection_id in session
 919        .connection_pool()
 920        .await
 921        .user_connection_ids(session.user_id)
 922    {
 923        session
 924            .peer
 925            .send(connection_id, proto::CallCanceled {})
 926            .trace_err();
 927    }
 928    update_user_contacts(session.user_id, &session).await?;
 929    Ok(())
 930}
 931
 932async fn update_participant_location(
 933    request: proto::UpdateParticipantLocation,
 934    response: Response<proto::UpdateParticipantLocation>,
 935    session: Session,
 936) -> Result<()> {
 937    let room_id = RoomId::from_proto(request.room_id);
 938    let location = request
 939        .location
 940        .ok_or_else(|| anyhow!("invalid location"))?;
 941    let room = session
 942        .db()
 943        .await
 944        .update_room_participant_location(room_id, session.connection_id, location)
 945        .await?;
 946    room_updated(&room, &session);
 947    response.send(proto::Ack {})?;
 948    Ok(())
 949}
 950
 951async fn share_project(
 952    request: proto::ShareProject,
 953    response: Response<proto::ShareProject>,
 954    session: Session,
 955) -> Result<()> {
 956    let (project_id, room) = &*session
 957        .db()
 958        .await
 959        .share_project(
 960            RoomId::from_proto(request.room_id),
 961            session.connection_id,
 962            &request.worktrees,
 963        )
 964        .await?;
 965    response.send(proto::ShareProjectResponse {
 966        project_id: project_id.to_proto(),
 967    })?;
 968    room_updated(&room, &session);
 969
 970    Ok(())
 971}
 972
 973async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
 974    let project_id = ProjectId::from_proto(message.project_id);
 975
 976    let (room, guest_connection_ids) = &*session
 977        .db()
 978        .await
 979        .unshare_project(project_id, session.connection_id)
 980        .await?;
 981
 982    broadcast(
 983        session.connection_id,
 984        guest_connection_ids.iter().copied(),
 985        |conn_id| session.peer.send(conn_id, message.clone()),
 986    );
 987    room_updated(&room, &session);
 988
 989    Ok(())
 990}
 991
 992async fn join_project(
 993    request: proto::JoinProject,
 994    response: Response<proto::JoinProject>,
 995    session: Session,
 996) -> Result<()> {
 997    let project_id = ProjectId::from_proto(request.project_id);
 998    let guest_user_id = session.user_id;
 999
1000    tracing::info!(%project_id, "join project");
1001
1002    let (project, replica_id) = &mut *session
1003        .db()
1004        .await
1005        .join_project(project_id, session.connection_id)
1006        .await?;
1007
1008    let collaborators = project
1009        .collaborators
1010        .iter()
1011        .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32)
1012        .map(|collaborator| proto::Collaborator {
1013            peer_id: collaborator.connection_id as u32,
1014            replica_id: collaborator.replica_id.0 as u32,
1015            user_id: collaborator.user_id.to_proto(),
1016        })
1017        .collect::<Vec<_>>();
1018    let worktrees = project
1019        .worktrees
1020        .iter()
1021        .map(|(id, worktree)| proto::WorktreeMetadata {
1022            id: *id,
1023            root_name: worktree.root_name.clone(),
1024            visible: worktree.visible,
1025            abs_path: worktree.abs_path.clone(),
1026        })
1027        .collect::<Vec<_>>();
1028
1029    for collaborator in &collaborators {
1030        session
1031            .peer
1032            .send(
1033                ConnectionId(collaborator.peer_id),
1034                proto::AddProjectCollaborator {
1035                    project_id: project_id.to_proto(),
1036                    collaborator: Some(proto::Collaborator {
1037                        peer_id: session.connection_id.0,
1038                        replica_id: replica_id.0 as u32,
1039                        user_id: guest_user_id.to_proto(),
1040                    }),
1041                },
1042            )
1043            .trace_err();
1044    }
1045
1046    // First, we send the metadata associated with each worktree.
1047    response.send(proto::JoinProjectResponse {
1048        worktrees: worktrees.clone(),
1049        replica_id: replica_id.0 as u32,
1050        collaborators: collaborators.clone(),
1051        language_servers: project.language_servers.clone(),
1052    })?;
1053
1054    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1055        #[cfg(any(test, feature = "test-support"))]
1056        const MAX_CHUNK_SIZE: usize = 2;
1057        #[cfg(not(any(test, feature = "test-support")))]
1058        const MAX_CHUNK_SIZE: usize = 256;
1059
1060        // Stream this worktree's entries.
1061        let message = proto::UpdateWorktree {
1062            project_id: project_id.to_proto(),
1063            worktree_id,
1064            abs_path: worktree.abs_path.clone(),
1065            root_name: worktree.root_name,
1066            updated_entries: worktree.entries,
1067            removed_entries: Default::default(),
1068            scan_id: worktree.scan_id,
1069            is_last_update: worktree.is_complete,
1070        };
1071        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1072            session.peer.send(session.connection_id, update.clone())?;
1073        }
1074
1075        // Stream this worktree's diagnostics.
1076        for summary in worktree.diagnostic_summaries {
1077            session.peer.send(
1078                session.connection_id,
1079                proto::UpdateDiagnosticSummary {
1080                    project_id: project_id.to_proto(),
1081                    worktree_id: worktree.id,
1082                    summary: Some(summary),
1083                },
1084            )?;
1085        }
1086    }
1087
1088    for language_server in &project.language_servers {
1089        session.peer.send(
1090            session.connection_id,
1091            proto::UpdateLanguageServer {
1092                project_id: project_id.to_proto(),
1093                language_server_id: language_server.id,
1094                variant: Some(
1095                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1096                        proto::LspDiskBasedDiagnosticsUpdated {},
1097                    ),
1098                ),
1099            },
1100        )?;
1101    }
1102
1103    Ok(())
1104}
1105
1106async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1107    let sender_id = session.connection_id;
1108    let project_id = ProjectId::from_proto(request.project_id);
1109
1110    let project = session
1111        .db()
1112        .await
1113        .leave_project(project_id, sender_id)
1114        .await?;
1115    tracing::info!(
1116        %project_id,
1117        host_user_id = %project.host_user_id,
1118        host_connection_id = %project.host_connection_id,
1119        "leave project"
1120    );
1121
1122    broadcast(
1123        sender_id,
1124        project.connection_ids.iter().copied(),
1125        |conn_id| {
1126            session.peer.send(
1127                conn_id,
1128                proto::RemoveProjectCollaborator {
1129                    project_id: project_id.to_proto(),
1130                    peer_id: sender_id.0,
1131                },
1132            )
1133        },
1134    );
1135
1136    Ok(())
1137}
1138
1139async fn update_project(
1140    request: proto::UpdateProject,
1141    response: Response<proto::UpdateProject>,
1142    session: Session,
1143) -> Result<()> {
1144    let project_id = ProjectId::from_proto(request.project_id);
1145    let (room, guest_connection_ids) = &*session
1146        .db()
1147        .await
1148        .update_project(project_id, session.connection_id, &request.worktrees)
1149        .await?;
1150    broadcast(
1151        session.connection_id,
1152        guest_connection_ids.iter().copied(),
1153        |connection_id| {
1154            session
1155                .peer
1156                .forward_send(session.connection_id, connection_id, request.clone())
1157        },
1158    );
1159    room_updated(&room, &session);
1160    response.send(proto::Ack {})?;
1161
1162    Ok(())
1163}
1164
1165async fn update_worktree(
1166    request: proto::UpdateWorktree,
1167    response: Response<proto::UpdateWorktree>,
1168    session: Session,
1169) -> Result<()> {
1170    let guest_connection_ids = session
1171        .db()
1172        .await
1173        .update_worktree(&request, session.connection_id)
1174        .await?;
1175
1176    broadcast(
1177        session.connection_id,
1178        guest_connection_ids.iter().copied(),
1179        |connection_id| {
1180            session
1181                .peer
1182                .forward_send(session.connection_id, connection_id, request.clone())
1183        },
1184    );
1185    response.send(proto::Ack {})?;
1186    Ok(())
1187}
1188
1189async fn update_diagnostic_summary(
1190    message: proto::UpdateDiagnosticSummary,
1191    session: Session,
1192) -> Result<()> {
1193    let guest_connection_ids = session
1194        .db()
1195        .await
1196        .update_diagnostic_summary(&message, session.connection_id)
1197        .await?;
1198
1199    broadcast(
1200        session.connection_id,
1201        guest_connection_ids.iter().copied(),
1202        |connection_id| {
1203            session
1204                .peer
1205                .forward_send(session.connection_id, connection_id, message.clone())
1206        },
1207    );
1208
1209    Ok(())
1210}
1211
1212async fn start_language_server(
1213    request: proto::StartLanguageServer,
1214    session: Session,
1215) -> Result<()> {
1216    let guest_connection_ids = session
1217        .db()
1218        .await
1219        .start_language_server(&request, session.connection_id)
1220        .await?;
1221
1222    broadcast(
1223        session.connection_id,
1224        guest_connection_ids.iter().copied(),
1225        |connection_id| {
1226            session
1227                .peer
1228                .forward_send(session.connection_id, connection_id, request.clone())
1229        },
1230    );
1231    Ok(())
1232}
1233
1234async fn update_language_server(
1235    request: proto::UpdateLanguageServer,
1236    session: Session,
1237) -> Result<()> {
1238    let project_id = ProjectId::from_proto(request.project_id);
1239    let project_connection_ids = session
1240        .db()
1241        .await
1242        .project_connection_ids(project_id, session.connection_id)
1243        .await?;
1244    broadcast(
1245        session.connection_id,
1246        project_connection_ids.iter().copied(),
1247        |connection_id| {
1248            session
1249                .peer
1250                .forward_send(session.connection_id, connection_id, request.clone())
1251        },
1252    );
1253    Ok(())
1254}
1255
1256async fn forward_project_request<T>(
1257    request: T,
1258    response: Response<T>,
1259    session: Session,
1260) -> Result<()>
1261where
1262    T: EntityMessage + RequestMessage,
1263{
1264    let project_id = ProjectId::from_proto(request.remote_entity_id());
1265    let host_connection_id = {
1266        let collaborators = session
1267            .db()
1268            .await
1269            .project_collaborators(project_id, session.connection_id)
1270            .await?;
1271        ConnectionId(
1272            collaborators
1273                .iter()
1274                .find(|collaborator| collaborator.is_host)
1275                .ok_or_else(|| anyhow!("host not found"))?
1276                .connection_id as u32,
1277        )
1278    };
1279
1280    let payload = session
1281        .peer
1282        .forward_request(session.connection_id, host_connection_id, request)
1283        .await?;
1284
1285    response.send(payload)?;
1286    Ok(())
1287}
1288
1289async fn save_buffer(
1290    request: proto::SaveBuffer,
1291    response: Response<proto::SaveBuffer>,
1292    session: Session,
1293) -> Result<()> {
1294    let project_id = ProjectId::from_proto(request.project_id);
1295    let host_connection_id = {
1296        let collaborators = session
1297            .db()
1298            .await
1299            .project_collaborators(project_id, session.connection_id)
1300            .await?;
1301        let host = collaborators
1302            .iter()
1303            .find(|collaborator| collaborator.is_host)
1304            .ok_or_else(|| anyhow!("host not found"))?;
1305        ConnectionId(host.connection_id as u32)
1306    };
1307    let response_payload = session
1308        .peer
1309        .forward_request(session.connection_id, host_connection_id, request.clone())
1310        .await?;
1311
1312    let mut collaborators = session
1313        .db()
1314        .await
1315        .project_collaborators(project_id, session.connection_id)
1316        .await?;
1317    collaborators
1318        .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32);
1319    let project_connection_ids = collaborators
1320        .iter()
1321        .map(|collaborator| ConnectionId(collaborator.connection_id as u32));
1322    broadcast(host_connection_id, project_connection_ids, |conn_id| {
1323        session
1324            .peer
1325            .forward_send(host_connection_id, conn_id, response_payload.clone())
1326    });
1327    response.send(response_payload)?;
1328    Ok(())
1329}
1330
1331async fn create_buffer_for_peer(
1332    request: proto::CreateBufferForPeer,
1333    session: Session,
1334) -> Result<()> {
1335    session.peer.forward_send(
1336        session.connection_id,
1337        ConnectionId(request.peer_id),
1338        request,
1339    )?;
1340    Ok(())
1341}
1342
1343async fn update_buffer(
1344    request: proto::UpdateBuffer,
1345    response: Response<proto::UpdateBuffer>,
1346    session: Session,
1347) -> Result<()> {
1348    let project_id = ProjectId::from_proto(request.project_id);
1349    let project_connection_ids = session
1350        .db()
1351        .await
1352        .project_connection_ids(project_id, session.connection_id)
1353        .await?;
1354
1355    broadcast(
1356        session.connection_id,
1357        project_connection_ids.iter().copied(),
1358        |connection_id| {
1359            session
1360                .peer
1361                .forward_send(session.connection_id, connection_id, request.clone())
1362        },
1363    );
1364    response.send(proto::Ack {})?;
1365    Ok(())
1366}
1367
1368async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
1369    let project_id = ProjectId::from_proto(request.project_id);
1370    let project_connection_ids = session
1371        .db()
1372        .await
1373        .project_connection_ids(project_id, session.connection_id)
1374        .await?;
1375
1376    broadcast(
1377        session.connection_id,
1378        project_connection_ids.iter().copied(),
1379        |connection_id| {
1380            session
1381                .peer
1382                .forward_send(session.connection_id, connection_id, request.clone())
1383        },
1384    );
1385    Ok(())
1386}
1387
1388async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
1389    let project_id = ProjectId::from_proto(request.project_id);
1390    let project_connection_ids = session
1391        .db()
1392        .await
1393        .project_connection_ids(project_id, session.connection_id)
1394        .await?;
1395    broadcast(
1396        session.connection_id,
1397        project_connection_ids.iter().copied(),
1398        |connection_id| {
1399            session
1400                .peer
1401                .forward_send(session.connection_id, connection_id, request.clone())
1402        },
1403    );
1404    Ok(())
1405}
1406
1407async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
1408    let project_id = ProjectId::from_proto(request.project_id);
1409    let project_connection_ids = session
1410        .db()
1411        .await
1412        .project_connection_ids(project_id, session.connection_id)
1413        .await?;
1414    broadcast(
1415        session.connection_id,
1416        project_connection_ids.iter().copied(),
1417        |connection_id| {
1418            session
1419                .peer
1420                .forward_send(session.connection_id, connection_id, request.clone())
1421        },
1422    );
1423    Ok(())
1424}
1425
1426async fn follow(
1427    request: proto::Follow,
1428    response: Response<proto::Follow>,
1429    session: Session,
1430) -> Result<()> {
1431    let project_id = ProjectId::from_proto(request.project_id);
1432    let leader_id = ConnectionId(request.leader_id);
1433    let follower_id = session.connection_id;
1434    {
1435        let project_connection_ids = session
1436            .db()
1437            .await
1438            .project_connection_ids(project_id, session.connection_id)
1439            .await?;
1440
1441        if !project_connection_ids.contains(&leader_id) {
1442            Err(anyhow!("no such peer"))?;
1443        }
1444    }
1445
1446    let mut response_payload = session
1447        .peer
1448        .forward_request(session.connection_id, leader_id, request)
1449        .await?;
1450    response_payload
1451        .views
1452        .retain(|view| view.leader_id != Some(follower_id.0));
1453    response.send(response_payload)?;
1454    Ok(())
1455}
1456
1457async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
1458    let project_id = ProjectId::from_proto(request.project_id);
1459    let leader_id = ConnectionId(request.leader_id);
1460    let project_connection_ids = session
1461        .db()
1462        .await
1463        .project_connection_ids(project_id, session.connection_id)
1464        .await?;
1465    if !project_connection_ids.contains(&leader_id) {
1466        Err(anyhow!("no such peer"))?;
1467    }
1468    session
1469        .peer
1470        .forward_send(session.connection_id, leader_id, request)?;
1471    Ok(())
1472}
1473
1474async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
1475    let project_id = ProjectId::from_proto(request.project_id);
1476    let project_connection_ids = session
1477        .db
1478        .lock()
1479        .await
1480        .project_connection_ids(project_id, session.connection_id)
1481        .await?;
1482
1483    let leader_id = request.variant.as_ref().and_then(|variant| match variant {
1484        proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
1485        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
1486        proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
1487    });
1488    for follower_id in &request.follower_ids {
1489        let follower_id = ConnectionId(*follower_id);
1490        if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id {
1491            session
1492                .peer
1493                .forward_send(session.connection_id, follower_id, request.clone())?;
1494        }
1495    }
1496    Ok(())
1497}
1498
1499async fn get_users(
1500    request: proto::GetUsers,
1501    response: Response<proto::GetUsers>,
1502    session: Session,
1503) -> Result<()> {
1504    let user_ids = request
1505        .user_ids
1506        .into_iter()
1507        .map(UserId::from_proto)
1508        .collect();
1509    let users = session
1510        .db()
1511        .await
1512        .get_users_by_ids(user_ids)
1513        .await?
1514        .into_iter()
1515        .map(|user| proto::User {
1516            id: user.id.to_proto(),
1517            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1518            github_login: user.github_login,
1519        })
1520        .collect();
1521    response.send(proto::UsersResponse { users })?;
1522    Ok(())
1523}
1524
1525async fn fuzzy_search_users(
1526    request: proto::FuzzySearchUsers,
1527    response: Response<proto::FuzzySearchUsers>,
1528    session: Session,
1529) -> Result<()> {
1530    let query = request.query;
1531    let users = match query.len() {
1532        0 => vec![],
1533        1 | 2 => session
1534            .db()
1535            .await
1536            .get_user_by_github_account(&query, None)
1537            .await?
1538            .into_iter()
1539            .collect(),
1540        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
1541    };
1542    let users = users
1543        .into_iter()
1544        .filter(|user| user.id != session.user_id)
1545        .map(|user| proto::User {
1546            id: user.id.to_proto(),
1547            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
1548            github_login: user.github_login,
1549        })
1550        .collect();
1551    response.send(proto::UsersResponse { users })?;
1552    Ok(())
1553}
1554
1555async fn request_contact(
1556    request: proto::RequestContact,
1557    response: Response<proto::RequestContact>,
1558    session: Session,
1559) -> Result<()> {
1560    let requester_id = session.user_id;
1561    let responder_id = UserId::from_proto(request.responder_id);
1562    if requester_id == responder_id {
1563        return Err(anyhow!("cannot add yourself as a contact"))?;
1564    }
1565
1566    session
1567        .db()
1568        .await
1569        .send_contact_request(requester_id, responder_id)
1570        .await?;
1571
1572    // Update outgoing contact requests of requester
1573    let mut update = proto::UpdateContacts::default();
1574    update.outgoing_requests.push(responder_id.to_proto());
1575    for connection_id in session
1576        .connection_pool()
1577        .await
1578        .user_connection_ids(requester_id)
1579    {
1580        session.peer.send(connection_id, update.clone())?;
1581    }
1582
1583    // Update incoming contact requests of responder
1584    let mut update = proto::UpdateContacts::default();
1585    update
1586        .incoming_requests
1587        .push(proto::IncomingContactRequest {
1588            requester_id: requester_id.to_proto(),
1589            should_notify: true,
1590        });
1591    for connection_id in session
1592        .connection_pool()
1593        .await
1594        .user_connection_ids(responder_id)
1595    {
1596        session.peer.send(connection_id, update.clone())?;
1597    }
1598
1599    response.send(proto::Ack {})?;
1600    Ok(())
1601}
1602
1603async fn respond_to_contact_request(
1604    request: proto::RespondToContactRequest,
1605    response: Response<proto::RespondToContactRequest>,
1606    session: Session,
1607) -> Result<()> {
1608    let responder_id = session.user_id;
1609    let requester_id = UserId::from_proto(request.requester_id);
1610    let db = session.db().await;
1611    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
1612        db.dismiss_contact_notification(responder_id, requester_id)
1613            .await?;
1614    } else {
1615        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
1616
1617        db.respond_to_contact_request(responder_id, requester_id, accept)
1618            .await?;
1619        let requester_busy = db.is_user_busy(requester_id).await?;
1620        let responder_busy = db.is_user_busy(responder_id).await?;
1621
1622        let pool = session.connection_pool().await;
1623        // Update responder with new contact
1624        let mut update = proto::UpdateContacts::default();
1625        if accept {
1626            update
1627                .contacts
1628                .push(contact_for_user(requester_id, false, requester_busy, &pool));
1629        }
1630        update
1631            .remove_incoming_requests
1632            .push(requester_id.to_proto());
1633        for connection_id in pool.user_connection_ids(responder_id) {
1634            session.peer.send(connection_id, update.clone())?;
1635        }
1636
1637        // Update requester with new contact
1638        let mut update = proto::UpdateContacts::default();
1639        if accept {
1640            update
1641                .contacts
1642                .push(contact_for_user(responder_id, true, responder_busy, &pool));
1643        }
1644        update
1645            .remove_outgoing_requests
1646            .push(responder_id.to_proto());
1647        for connection_id in pool.user_connection_ids(requester_id) {
1648            session.peer.send(connection_id, update.clone())?;
1649        }
1650    }
1651
1652    response.send(proto::Ack {})?;
1653    Ok(())
1654}
1655
1656async fn remove_contact(
1657    request: proto::RemoveContact,
1658    response: Response<proto::RemoveContact>,
1659    session: Session,
1660) -> Result<()> {
1661    let requester_id = session.user_id;
1662    let responder_id = UserId::from_proto(request.user_id);
1663    let db = session.db().await;
1664    db.remove_contact(requester_id, responder_id).await?;
1665
1666    let pool = session.connection_pool().await;
1667    // Update outgoing contact requests of requester
1668    let mut update = proto::UpdateContacts::default();
1669    update
1670        .remove_outgoing_requests
1671        .push(responder_id.to_proto());
1672    for connection_id in pool.user_connection_ids(requester_id) {
1673        session.peer.send(connection_id, update.clone())?;
1674    }
1675
1676    // Update incoming contact requests of responder
1677    let mut update = proto::UpdateContacts::default();
1678    update
1679        .remove_incoming_requests
1680        .push(requester_id.to_proto());
1681    for connection_id in pool.user_connection_ids(responder_id) {
1682        session.peer.send(connection_id, update.clone())?;
1683    }
1684
1685    response.send(proto::Ack {})?;
1686    Ok(())
1687}
1688
1689async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
1690    let project_id = ProjectId::from_proto(request.project_id);
1691    let project_connection_ids = session
1692        .db()
1693        .await
1694        .project_connection_ids(project_id, session.connection_id)
1695        .await?;
1696    broadcast(
1697        session.connection_id,
1698        project_connection_ids.iter().copied(),
1699        |connection_id| {
1700            session
1701                .peer
1702                .forward_send(session.connection_id, connection_id, request.clone())
1703        },
1704    );
1705    Ok(())
1706}
1707
1708async fn get_private_user_info(
1709    _request: proto::GetPrivateUserInfo,
1710    response: Response<proto::GetPrivateUserInfo>,
1711    session: Session,
1712) -> Result<()> {
1713    let metrics_id = session
1714        .db()
1715        .await
1716        .get_user_metrics_id(session.user_id)
1717        .await?;
1718    let user = session
1719        .db()
1720        .await
1721        .get_user_by_id(session.user_id)
1722        .await?
1723        .ok_or_else(|| anyhow!("user not found"))?;
1724    response.send(proto::GetPrivateUserInfoResponse {
1725        metrics_id,
1726        staff: user.admin,
1727    })?;
1728    Ok(())
1729}
1730
1731fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
1732    match message {
1733        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
1734        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
1735        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
1736        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
1737        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
1738            code: frame.code.into(),
1739            reason: frame.reason,
1740        })),
1741    }
1742}
1743
1744fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
1745    match message {
1746        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
1747        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
1748        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
1749        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
1750        AxumMessage::Close(frame) => {
1751            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
1752                code: frame.code.into(),
1753                reason: frame.reason,
1754            }))
1755        }
1756    }
1757}
1758
1759fn build_initial_contacts_update(
1760    contacts: Vec<db::Contact>,
1761    pool: &ConnectionPool,
1762) -> proto::UpdateContacts {
1763    let mut update = proto::UpdateContacts::default();
1764
1765    for contact in contacts {
1766        match contact {
1767            db::Contact::Accepted {
1768                user_id,
1769                should_notify,
1770                busy,
1771            } => {
1772                update
1773                    .contacts
1774                    .push(contact_for_user(user_id, should_notify, busy, &pool));
1775            }
1776            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
1777            db::Contact::Incoming {
1778                user_id,
1779                should_notify,
1780            } => update
1781                .incoming_requests
1782                .push(proto::IncomingContactRequest {
1783                    requester_id: user_id.to_proto(),
1784                    should_notify,
1785                }),
1786        }
1787    }
1788
1789    update
1790}
1791
1792fn contact_for_user(
1793    user_id: UserId,
1794    should_notify: bool,
1795    busy: bool,
1796    pool: &ConnectionPool,
1797) -> proto::Contact {
1798    proto::Contact {
1799        user_id: user_id.to_proto(),
1800        online: pool.is_user_online(user_id),
1801        busy,
1802        should_notify,
1803    }
1804}
1805
1806fn room_updated(room: &proto::Room, session: &Session) {
1807    for participant in &room.participants {
1808        session
1809            .peer
1810            .send(
1811                ConnectionId(participant.peer_id),
1812                proto::RoomUpdated {
1813                    room: Some(room.clone()),
1814                },
1815            )
1816            .trace_err();
1817    }
1818}
1819
1820async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
1821    let db = session.db().await;
1822    let contacts = db.get_contacts(user_id).await?;
1823    let busy = db.is_user_busy(user_id).await?;
1824
1825    let pool = session.connection_pool().await;
1826    let updated_contact = contact_for_user(user_id, false, busy, &pool);
1827    for contact in contacts {
1828        if let db::Contact::Accepted {
1829            user_id: contact_user_id,
1830            ..
1831        } = contact
1832        {
1833            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
1834                session
1835                    .peer
1836                    .send(
1837                        contact_conn_id,
1838                        proto::UpdateContacts {
1839                            contacts: vec![updated_contact.clone()],
1840                            remove_contacts: Default::default(),
1841                            incoming_requests: Default::default(),
1842                            remove_incoming_requests: Default::default(),
1843                            outgoing_requests: Default::default(),
1844                            remove_outgoing_requests: Default::default(),
1845                        },
1846                    )
1847                    .trace_err();
1848            }
1849        }
1850    }
1851    Ok(())
1852}
1853
1854async fn leave_room_for_session(session: &Session) -> Result<()> {
1855    let mut contacts_to_update = HashSet::default();
1856
1857    let canceled_calls_to_user_ids;
1858    let live_kit_room;
1859    let delete_live_kit_room;
1860    {
1861        let mut left_room = session.db().await.leave_room(session.connection_id).await?;
1862        contacts_to_update.insert(session.user_id);
1863
1864        for project in left_room.left_projects.values() {
1865            for connection_id in &project.connection_ids {
1866                if project.host_user_id == session.user_id {
1867                    session
1868                        .peer
1869                        .send(
1870                            *connection_id,
1871                            proto::UnshareProject {
1872                                project_id: project.id.to_proto(),
1873                            },
1874                        )
1875                        .trace_err();
1876                } else {
1877                    session
1878                        .peer
1879                        .send(
1880                            *connection_id,
1881                            proto::RemoveProjectCollaborator {
1882                                project_id: project.id.to_proto(),
1883                                peer_id: session.connection_id.0,
1884                            },
1885                        )
1886                        .trace_err();
1887                }
1888            }
1889
1890            session
1891                .peer
1892                .send(
1893                    session.connection_id,
1894                    proto::UnshareProject {
1895                        project_id: project.id.to_proto(),
1896                    },
1897                )
1898                .trace_err();
1899        }
1900
1901        room_updated(&left_room.room, &session);
1902        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
1903        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
1904        delete_live_kit_room = left_room.room.participants.is_empty();
1905    }
1906
1907    {
1908        let pool = session.connection_pool().await;
1909        for canceled_user_id in canceled_calls_to_user_ids {
1910            for connection_id in pool.user_connection_ids(canceled_user_id) {
1911                session
1912                    .peer
1913                    .send(connection_id, proto::CallCanceled {})
1914                    .trace_err();
1915            }
1916            contacts_to_update.insert(canceled_user_id);
1917        }
1918    }
1919
1920    for contact_user_id in contacts_to_update {
1921        update_user_contacts(contact_user_id, &session).await?;
1922    }
1923
1924    if let Some(live_kit) = session.live_kit_client.as_ref() {
1925        live_kit
1926            .remove_participant(live_kit_room.clone(), session.connection_id.to_string())
1927            .await
1928            .trace_err();
1929
1930        if delete_live_kit_room {
1931            live_kit.delete_room(live_kit_room).await.trace_err();
1932        }
1933    }
1934
1935    Ok(())
1936}
1937
/// Extension trait for turning a fallible value into an `Option`.
///
/// The implementation for `Result` in this file logs the error through
/// `tracing::error!` and yields `None`, which lets best-effort call sites
/// discard failures without silently losing them.
pub trait ResultExt {
    /// The success type produced when no error occurred.
    type Ok;

    /// Consumes `self`, returning `Some(value)` on success and `None` on failure.
    fn trace_err(self) -> Option<Self::Ok>;
}
1943
1944impl<T, E> ResultExt for Result<T, E>
1945where
1946    E: std::fmt::Debug,
1947{
1948    type Ok = T;
1949
1950    fn trace_err(self) -> Option<T> {
1951        match self {
1952            Ok(value) => Some(value),
1953            Err(error) => {
1954                tracing::error!("{:?}", error);
1955                None
1956            }
1957        }
1958    }
1959}