rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   7        CreatedChannelMessage, Database, DevServerId, InviteMemberResult, MembershipUpdated,
   8        MessageId, NotificationId, PrincipalId, Project, ProjectId, RejoinedProject,
   9        RemoteProjectId, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  10        ServerId, UpdatedChannelMessage, User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Error, RateLimit, RateLimiter, Result,
  14};
  15use anyhow::{anyhow, Context as _};
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35
  36use futures::{
  37    channel::oneshot,
  38    future::{self, BoxFuture},
  39    stream::FuturesUnordered,
  40    FutureExt, SinkExt, StreamExt, TryStreamExt,
  41};
  42use prometheus::{register_int_gauge, IntGauge};
  43use rpc::{
  44    proto::{
  45        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
  46        LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  47    },
  48    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  49};
  50use semantic_version::SemanticVersion;
  51use serde::{Serialize, Serializer};
  52use std::{
  53    any::TypeId,
  54    future::Future,
  55    marker::PhantomData,
  56    mem,
  57    net::SocketAddr,
  58    ops::{Deref, DerefMut},
  59    rc::Rc,
  60    sync::{
  61        atomic::{AtomicBool, Ordering::SeqCst},
  62        Arc, OnceLock,
  63    },
  64    time::{Duration, Instant},
  65};
  66use time::OffsetDateTime;
  67use tokio::sync::{watch, Semaphore};
  68use tower::ServiceBuilder;
  69use tracing::{
  70    field::{self},
  71    info_span, instrument, Instrument,
  72};
  73use util::http::IsahcHttpClient;
  74
  75pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  76
  77// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  78pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  79
  80const MESSAGE_COUNT_PER_PAGE: usize = 100;
  81const MAX_MESSAGE_LEN: usize = 1024;
  82const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  83
  84type MessageHandler =
  85    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  86
  87struct Response<R> {
  88    peer: Arc<Peer>,
  89    receipt: Receipt<R>,
  90    responded: Arc<AtomicBool>,
  91}
  92
  93impl<R: RequestMessage> Response<R> {
  94    fn send(self, payload: R::Response) -> Result<()> {
  95        self.responded.store(true, SeqCst);
  96        self.peer.respond(self.receipt, payload)?;
  97        Ok(())
  98    }
  99}
 100
 101struct StreamingResponse<R: RequestMessage> {
 102    peer: Arc<Peer>,
 103    receipt: Receipt<R>,
 104}
 105
 106impl<R: RequestMessage> StreamingResponse<R> {
 107    fn send(&self, payload: R::Response) -> Result<()> {
 108        self.peer.respond(self.receipt, payload)?;
 109        Ok(())
 110    }
 111}
 112
 113#[derive(Clone, Debug)]
 114pub enum Principal {
 115    User(User),
 116    Impersonated { user: User, admin: User },
 117    DevServer(dev_server::Model),
 118}
 119
 120impl Principal {
 121    fn update_span(&self, span: &tracing::Span) {
 122        match &self {
 123            Principal::User(user) => {
 124                span.record("user_id", &user.id.0);
 125                span.record("login", &user.github_login);
 126            }
 127            Principal::Impersonated { user, admin } => {
 128                span.record("user_id", &user.id.0);
 129                span.record("login", &user.github_login);
 130                span.record("impersonator", &admin.github_login);
 131            }
 132            Principal::DevServer(dev_server) => {
 133                span.record("dev_server_id", &dev_server.id.0);
 134            }
 135        }
 136    }
 137}
 138
 139#[derive(Clone)]
 140struct Session {
 141    principal: Principal,
 142    connection_id: ConnectionId,
 143    db: Arc<tokio::sync::Mutex<DbHandle>>,
 144    peer: Arc<Peer>,
 145    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 146    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 147    http_client: IsahcHttpClient,
 148    rate_limiter: Arc<RateLimiter>,
 149    _executor: Executor,
 150}
 151
 152impl Session {
 153    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 154        #[cfg(test)]
 155        tokio::task::yield_now().await;
 156        let guard = self.db.lock().await;
 157        #[cfg(test)]
 158        tokio::task::yield_now().await;
 159        guard
 160    }
 161
 162    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 163        #[cfg(test)]
 164        tokio::task::yield_now().await;
 165        let guard = self.connection_pool.lock();
 166        ConnectionPoolGuard {
 167            guard,
 168            _not_send: PhantomData,
 169        }
 170    }
 171
 172    fn for_user(self) -> Option<UserSession> {
 173        UserSession::new(self)
 174    }
 175
 176    fn for_dev_server(self) -> Option<DevServerSession> {
 177        DevServerSession::new(self)
 178    }
 179
 180    fn user_id(&self) -> Option<UserId> {
 181        match &self.principal {
 182            Principal::User(user) => Some(user.id),
 183            Principal::Impersonated { user, .. } => Some(user.id),
 184            Principal::DevServer(_) => None,
 185        }
 186    }
 187
 188    fn dev_server_id(&self) -> Option<DevServerId> {
 189        match &self.principal {
 190            Principal::User(_) | Principal::Impersonated { .. } => None,
 191            Principal::DevServer(dev_server) => Some(dev_server.id),
 192        }
 193    }
 194
 195    fn principal_id(&self) -> PrincipalId {
 196        match &self.principal {
 197            Principal::User(user) => PrincipalId::UserId(user.id),
 198            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 199            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 200        }
 201    }
 202}
 203
 204impl Debug for Session {
 205    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 206        let mut result = f.debug_struct("Session");
 207        match &self.principal {
 208            Principal::User(user) => {
 209                result.field("user", &user.github_login);
 210            }
 211            Principal::Impersonated { user, admin } => {
 212                result.field("user", &user.github_login);
 213                result.field("impersonator", &admin.github_login);
 214            }
 215            Principal::DevServer(dev_server) => {
 216                result.field("dev_server", &dev_server.id);
 217            }
 218        }
 219        result.field("connection_id", &self.connection_id).finish()
 220    }
 221}
 222
 223struct UserSession(Session);
 224
 225impl UserSession {
 226    pub fn new(s: Session) -> Option<Self> {
 227        s.user_id().map(|_| UserSession(s))
 228    }
 229    pub fn user_id(&self) -> UserId {
 230        self.0.user_id().unwrap()
 231    }
 232}
 233
 234impl Deref for UserSession {
 235    type Target = Session;
 236
 237    fn deref(&self) -> &Self::Target {
 238        &self.0
 239    }
 240}
 241impl DerefMut for UserSession {
 242    fn deref_mut(&mut self) -> &mut Self::Target {
 243        &mut self.0
 244    }
 245}
 246
 247struct DevServerSession(Session);
 248
 249impl DevServerSession {
 250    pub fn new(s: Session) -> Option<Self> {
 251        s.dev_server_id().map(|_| DevServerSession(s))
 252    }
 253    pub fn dev_server_id(&self) -> DevServerId {
 254        self.0.dev_server_id().unwrap()
 255    }
 256}
 257
 258impl Deref for DevServerSession {
 259    type Target = Session;
 260
 261    fn deref(&self) -> &Self::Target {
 262        &self.0
 263    }
 264}
 265impl DerefMut for DevServerSession {
 266    fn deref_mut(&mut self) -> &mut Self::Target {
 267        &mut self.0
 268    }
 269}
 270
 271fn user_handler<M: RequestMessage, Fut>(
 272    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 273) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 274where
 275    Fut: Send + Future<Output = Result<()>>,
 276{
 277    let handler = Arc::new(handler);
 278    move |message, response, session| {
 279        let handler = handler.clone();
 280        Box::pin(async move {
 281            if let Some(user_session) = session.for_user() {
 282                Ok(handler(message, response, user_session).await?)
 283            } else {
 284                Err(Error::Internal(anyhow!(
 285                    "must be a user to call {}",
 286                    M::NAME
 287                )))
 288            }
 289        })
 290    }
 291}
 292
 293fn dev_server_handler<M: RequestMessage, Fut>(
 294    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 295) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 296where
 297    Fut: Send + Future<Output = Result<()>>,
 298{
 299    let handler = Arc::new(handler);
 300    move |message, response, session| {
 301        let handler = handler.clone();
 302        Box::pin(async move {
 303            if let Some(dev_server_session) = session.for_dev_server() {
 304                Ok(handler(message, response, dev_server_session).await?)
 305            } else {
 306                Err(Error::Internal(anyhow!(
 307                    "must be a dev server to call {}",
 308                    M::NAME
 309                )))
 310            }
 311        })
 312    }
 313}
 314
 315fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 316    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 317) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 318where
 319    InnertRetFut: Send + Future<Output = Result<()>>,
 320{
 321    let handler = Arc::new(handler);
 322    move |message, session| {
 323        let handler = handler.clone();
 324        Box::pin(async move {
 325            if let Some(user_session) = session.for_user() {
 326                Ok(handler(message, user_session).await?)
 327            } else {
 328                Err(Error::Internal(anyhow!(
 329                    "must be a user to call {}",
 330                    M::NAME
 331                )))
 332            }
 333        })
 334    }
 335}
 336
 337struct DbHandle(Arc<Database>);
 338
 339impl Deref for DbHandle {
 340    type Target = Database;
 341
 342    fn deref(&self) -> &Self::Target {
 343        self.0.as_ref()
 344    }
 345}
 346
  /// The collaboration RPC server.
  ///
  /// Owns the `Peer` used to send to and respond on connections, the shared connection
  /// pool, and the per-message-type handler table.
  pub struct Server {
      // Mutable through `&self`. The writer in this part of the file is the test-only `reset`.
      id: parking_lot::Mutex<ServerId>,
      peer: Arc<Peer>,
      pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
      app_state: Arc<AppState>,
      // One handler per message type, keyed by the message's `TypeId`; a duplicate panics.
      handlers: HashMap<TypeId, MessageHandler>,
      // Set to `true` by `teardown` and back to `false` by the test-only `reset`.
      teardown: watch::Sender<bool>,
  }
 355
  /// Lock guard for the shared `ConnectionPool`, e.g. as returned by `Session::connection_pool`.
  ///
  /// The `PhantomData<Rc<()>>` marker makes this type `!Send` (`Rc` is `!Send`). Any
  /// future that holds it across an `.await` is therefore `!Send` as well — presumably
  /// to stop handlers from keeping the pool locked across suspension points.
  pub(crate) struct ConnectionPoolGuard<'a> {
      guard: parking_lot::MutexGuard<'a, ConnectionPool>,
      _not_send: PhantomData<Rc<()>>,
  }
 360
  /// Serializable view of the server: its `Peer` plus the locked connection pool.
  ///
  /// NOTE(review): `connection_pool` holds a live `parking_lot` guard, so the pool stays
  /// locked for as long as a snapshot exists. Keep snapshots short-lived.
  #[derive(Serialize)]
  pub struct ServerSnapshot<'a> {
      peer: &'a Peer,
      // Serialized as the pool it guards, via `serialize_deref`.
      #[serde(serialize_with = "serialize_deref")]
      connection_pool: ConnectionPoolGuard<'a>,
  }
 367
 368pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 369where
 370    S: Serializer,
 371    T: Deref<Target = U>,
 372    U: Serialize,
 373{
 374    Serialize::serialize(value.deref(), serializer)
 375}
 376
 377impl Server {
 378    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 379        let mut server = Self {
 380            id: parking_lot::Mutex::new(id),
 381            peer: Peer::new(id.0 as u32),
 382            app_state: app_state.clone(),
 383            connection_pool: Default::default(),
 384            handlers: Default::default(),
 385            teardown: watch::channel(false).0,
 386        };
 387
 388        server
 389            .add_request_handler(ping)
 390            .add_request_handler(user_handler(create_room))
 391            .add_request_handler(user_handler(join_room))
 392            .add_request_handler(user_handler(rejoin_room))
 393            .add_request_handler(user_handler(leave_room))
 394            .add_request_handler(user_handler(set_room_participant_role))
 395            .add_request_handler(user_handler(call))
 396            .add_request_handler(user_handler(cancel_call))
 397            .add_message_handler(user_message_handler(decline_call))
 398            .add_request_handler(user_handler(update_participant_location))
 399            .add_request_handler(user_handler(share_project))
 400            .add_message_handler(unshare_project)
 401            .add_request_handler(user_handler(join_project))
 402            .add_request_handler(user_handler(join_hosted_project))
 403            .add_request_handler(user_handler(rejoin_remote_projects))
 404            .add_request_handler(user_handler(create_remote_project))
 405            .add_request_handler(user_handler(create_dev_server))
 406            .add_request_handler(dev_server_handler(share_remote_project))
 407            .add_request_handler(dev_server_handler(shutdown_dev_server))
 408            .add_request_handler(dev_server_handler(reconnect_dev_server))
 409            .add_message_handler(user_message_handler(leave_project))
 410            .add_request_handler(update_project)
 411            .add_request_handler(update_worktree)
 412            .add_message_handler(start_language_server)
 413            .add_message_handler(update_language_server)
 414            .add_message_handler(update_diagnostic_summary)
 415            .add_message_handler(update_worktree_settings)
 416            .add_request_handler(user_handler(
 417                forward_read_only_project_request::<proto::GetHover>,
 418            ))
 419            .add_request_handler(user_handler(
 420                forward_read_only_project_request::<proto::GetDefinition>,
 421            ))
 422            .add_request_handler(user_handler(
 423                forward_read_only_project_request::<proto::GetTypeDefinition>,
 424            ))
 425            .add_request_handler(user_handler(
 426                forward_read_only_project_request::<proto::GetReferences>,
 427            ))
 428            .add_request_handler(user_handler(
 429                forward_read_only_project_request::<proto::SearchProject>,
 430            ))
 431            .add_request_handler(user_handler(
 432                forward_read_only_project_request::<proto::GetDocumentHighlights>,
 433            ))
 434            .add_request_handler(user_handler(
 435                forward_read_only_project_request::<proto::GetProjectSymbols>,
 436            ))
 437            .add_request_handler(user_handler(
 438                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
 439            ))
 440            .add_request_handler(user_handler(
 441                forward_read_only_project_request::<proto::OpenBufferById>,
 442            ))
 443            .add_request_handler(user_handler(
 444                forward_read_only_project_request::<proto::SynchronizeBuffers>,
 445            ))
 446            .add_request_handler(user_handler(
 447                forward_read_only_project_request::<proto::InlayHints>,
 448            ))
 449            .add_request_handler(user_handler(
 450                forward_read_only_project_request::<proto::OpenBufferByPath>,
 451            ))
 452            .add_request_handler(user_handler(
 453                forward_mutating_project_request::<proto::GetCompletions>,
 454            ))
 455            .add_request_handler(user_handler(
 456                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 457            ))
 458            .add_request_handler(user_handler(
 459                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 460            ))
 461            .add_request_handler(user_handler(
 462                forward_mutating_project_request::<proto::GetCodeActions>,
 463            ))
 464            .add_request_handler(user_handler(
 465                forward_mutating_project_request::<proto::ApplyCodeAction>,
 466            ))
 467            .add_request_handler(user_handler(
 468                forward_mutating_project_request::<proto::PrepareRename>,
 469            ))
 470            .add_request_handler(user_handler(
 471                forward_mutating_project_request::<proto::PerformRename>,
 472            ))
 473            .add_request_handler(user_handler(
 474                forward_mutating_project_request::<proto::ReloadBuffers>,
 475            ))
 476            .add_request_handler(user_handler(
 477                forward_mutating_project_request::<proto::FormatBuffers>,
 478            ))
 479            .add_request_handler(user_handler(
 480                forward_mutating_project_request::<proto::CreateProjectEntry>,
 481            ))
 482            .add_request_handler(user_handler(
 483                forward_mutating_project_request::<proto::RenameProjectEntry>,
 484            ))
 485            .add_request_handler(user_handler(
 486                forward_mutating_project_request::<proto::CopyProjectEntry>,
 487            ))
 488            .add_request_handler(user_handler(
 489                forward_mutating_project_request::<proto::DeleteProjectEntry>,
 490            ))
 491            .add_request_handler(user_handler(
 492                forward_mutating_project_request::<proto::ExpandProjectEntry>,
 493            ))
 494            .add_request_handler(user_handler(
 495                forward_mutating_project_request::<proto::OnTypeFormatting>,
 496            ))
 497            .add_request_handler(user_handler(
 498                forward_mutating_project_request::<proto::SaveBuffer>,
 499            ))
 500            .add_request_handler(user_handler(
 501                forward_mutating_project_request::<proto::BlameBuffer>,
 502            ))
 503            .add_request_handler(user_handler(
 504                forward_mutating_project_request::<proto::MultiLspQuery>,
 505            ))
 506            .add_message_handler(create_buffer_for_peer)
 507            .add_request_handler(update_buffer)
 508            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 509            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 510            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 511            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 512            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 513            .add_request_handler(get_users)
 514            .add_request_handler(user_handler(fuzzy_search_users))
 515            .add_request_handler(user_handler(request_contact))
 516            .add_request_handler(user_handler(remove_contact))
 517            .add_request_handler(user_handler(respond_to_contact_request))
 518            .add_request_handler(user_handler(create_channel))
 519            .add_request_handler(user_handler(delete_channel))
 520            .add_request_handler(user_handler(invite_channel_member))
 521            .add_request_handler(user_handler(remove_channel_member))
 522            .add_request_handler(user_handler(set_channel_member_role))
 523            .add_request_handler(user_handler(set_channel_visibility))
 524            .add_request_handler(user_handler(rename_channel))
 525            .add_request_handler(user_handler(join_channel_buffer))
 526            .add_request_handler(user_handler(leave_channel_buffer))
 527            .add_message_handler(user_message_handler(update_channel_buffer))
 528            .add_request_handler(user_handler(rejoin_channel_buffers))
 529            .add_request_handler(user_handler(get_channel_members))
 530            .add_request_handler(user_handler(respond_to_channel_invite))
 531            .add_request_handler(user_handler(join_channel))
 532            .add_request_handler(user_handler(join_channel_chat))
 533            .add_message_handler(user_message_handler(leave_channel_chat))
 534            .add_request_handler(user_handler(send_channel_message))
 535            .add_request_handler(user_handler(remove_channel_message))
 536            .add_request_handler(user_handler(update_channel_message))
 537            .add_request_handler(user_handler(get_channel_messages))
 538            .add_request_handler(user_handler(get_channel_messages_by_id))
 539            .add_request_handler(user_handler(get_notifications))
 540            .add_request_handler(user_handler(mark_notification_as_read))
 541            .add_request_handler(user_handler(move_channel))
 542            .add_request_handler(user_handler(follow))
 543            .add_message_handler(user_message_handler(unfollow))
 544            .add_message_handler(user_message_handler(update_followers))
 545            .add_request_handler(user_handler(get_private_user_info))
 546            .add_message_handler(user_message_handler(acknowledge_channel_message))
 547            .add_message_handler(user_message_handler(acknowledge_buffer_version))
 548            .add_streaming_request_handler({
 549                let app_state = app_state.clone();
 550                move |request, response, session| {
 551                    complete_with_language_model(
 552                        request,
 553                        response,
 554                        session,
 555                        app_state.config.openai_api_key.clone(),
 556                        app_state.config.google_ai_api_key.clone(),
 557                        app_state.config.anthropic_api_key.clone(),
 558                    )
 559                }
 560            })
 561            .add_request_handler({
 562                let app_state = app_state.clone();
 563                user_handler(move |request, response, session| {
 564                    count_tokens_with_language_model(
 565                        request,
 566                        response,
 567                        session,
 568                        app_state.config.google_ai_api_key.clone(),
 569                    )
 570                })
 571            });
 572
 573        Arc::new(server)
 574    }
 575
 576    pub async fn start(&self) -> Result<()> {
 577        let server_id = *self.id.lock();
 578        let app_state = self.app_state.clone();
 579        let peer = self.peer.clone();
 580        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 581        let pool = self.connection_pool.clone();
 582        let live_kit_client = self.app_state.live_kit_client.clone();
 583
 584        let span = info_span!("start server");
 585        self.app_state.executor.spawn_detached(
 586            async move {
 587                tracing::info!("waiting for cleanup timeout");
 588                timeout.await;
 589                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 590                if let Some((room_ids, channel_ids)) = app_state
 591                    .db
 592                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 593                    .await
 594                    .trace_err()
 595                {
 596                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 597                    tracing::info!(
 598                        stale_channel_buffer_count = channel_ids.len(),
 599                        "retrieved stale channel buffers"
 600                    );
 601
 602                    for channel_id in channel_ids {
 603                        if let Some(refreshed_channel_buffer) = app_state
 604                            .db
 605                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 606                            .await
 607                            .trace_err()
 608                        {
 609                            for connection_id in refreshed_channel_buffer.connection_ids {
 610                                peer.send(
 611                                    connection_id,
 612                                    proto::UpdateChannelBufferCollaborators {
 613                                        channel_id: channel_id.to_proto(),
 614                                        collaborators: refreshed_channel_buffer
 615                                            .collaborators
 616                                            .clone(),
 617                                    },
 618                                )
 619                                .trace_err();
 620                            }
 621                        }
 622                    }
 623
 624                    for room_id in room_ids {
 625                        let mut contacts_to_update = HashSet::default();
 626                        let mut canceled_calls_to_user_ids = Vec::new();
 627                        let mut live_kit_room = String::new();
 628                        let mut delete_live_kit_room = false;
 629
 630                        if let Some(mut refreshed_room) = app_state
 631                            .db
 632                            .clear_stale_room_participants(room_id, server_id)
 633                            .await
 634                            .trace_err()
 635                        {
 636                            tracing::info!(
 637                                room_id = room_id.0,
 638                                new_participant_count = refreshed_room.room.participants.len(),
 639                                "refreshed room"
 640                            );
 641                            room_updated(&refreshed_room.room, &peer);
 642                            if let Some(channel) = refreshed_room.channel.as_ref() {
 643                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 644                            }
 645                            contacts_to_update
 646                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 647                            contacts_to_update
 648                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 649                            canceled_calls_to_user_ids =
 650                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 651                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 652                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 653                        }
 654
 655                        {
 656                            let pool = pool.lock();
 657                            for canceled_user_id in canceled_calls_to_user_ids {
 658                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 659                                    peer.send(
 660                                        connection_id,
 661                                        proto::CallCanceled {
 662                                            room_id: room_id.to_proto(),
 663                                        },
 664                                    )
 665                                    .trace_err();
 666                                }
 667                            }
 668                        }
 669
 670                        for user_id in contacts_to_update {
 671                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 672                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 673                            if let Some((busy, contacts)) = busy.zip(contacts) {
 674                                let pool = pool.lock();
 675                                let updated_contact = contact_for_user(user_id, busy, &pool);
 676                                for contact in contacts {
 677                                    if let db::Contact::Accepted {
 678                                        user_id: contact_user_id,
 679                                        ..
 680                                    } = contact
 681                                    {
 682                                        for contact_conn_id in
 683                                            pool.user_connection_ids(contact_user_id)
 684                                        {
 685                                            peer.send(
 686                                                contact_conn_id,
 687                                                proto::UpdateContacts {
 688                                                    contacts: vec![updated_contact.clone()],
 689                                                    remove_contacts: Default::default(),
 690                                                    incoming_requests: Default::default(),
 691                                                    remove_incoming_requests: Default::default(),
 692                                                    outgoing_requests: Default::default(),
 693                                                    remove_outgoing_requests: Default::default(),
 694                                                },
 695                                            )
 696                                            .trace_err();
 697                                        }
 698                                    }
 699                                }
 700                            }
 701                        }
 702
 703                        if let Some(live_kit) = live_kit_client.as_ref() {
 704                            if delete_live_kit_room {
 705                                live_kit.delete_room(live_kit_room).await.trace_err();
 706                            }
 707                        }
 708                    }
 709                }
 710
 711                app_state
 712                    .db
 713                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 714                    .await
 715                    .trace_err();
 716            }
 717            .instrument(span),
 718        );
 719        Ok(())
 720    }
 721
 722    pub fn teardown(&self) {
 723        self.peer.teardown();
 724        self.connection_pool.lock().reset();
 725        let _ = self.teardown.send(true);
 726    }
 727
 728    #[cfg(test)]
 729    pub fn reset(&self, id: ServerId) {
 730        self.teardown();
 731        *self.id.lock() = id;
 732        self.peer.reset(id.0 as u32);
 733        let _ = self.teardown.send(false);
 734    }
 735
 736    #[cfg(test)]
 737    pub fn id(&self) -> ServerId {
 738        *self.id.lock()
 739    }
 740
    /// Registers `handler` as the handler for messages of type `M`.
    ///
    /// The stored, type-erased closure downcasts the incoming envelope to
    /// `TypedEnvelope<M>`, logs "message received", and invokes `handler`
    /// with the envelope and the session. It returns a boxed future that awaits
    /// the handler and then logs three timings:
    /// - `total_duration_ms`: time since the envelope's `received_at`;
    /// - `processing_duration_ms`: time since just before `handler` was invoked;
    /// - `queue_duration_ms`: the difference between the two.
    ///
    /// Handler errors are logged (with the payload type name) and not
    /// propagated further.
    ///
    /// # Panics
    ///
    /// Panics if a handler for `M` has already been registered. The stored
    /// closure also panics (via `unwrap`) if it receives an envelope that is
    /// not a `TypedEnvelope<M>`.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        // Handlers are keyed by `TypeId::of::<M>()`; the returned `Option`
        // holds any handler previously registered for the same type.
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session| {
                // Presumably callers look handlers up by the payload's type id,
                // so this downcast is expected to succeed — confirm at call sites.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                    tracing::info!(
                        "message received"
                    );
                // Taken before the handler is invoked, so "processing" also covers
                // any time the returned future waits before it is first polled.
                let start_time = Instant::now();
                let future = (handler)(*envelope, session);
                async move {
                    let result = future.await;
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    let payload_type = M::NAME;
                    // The local variable names above double as the emitted tracing field names.
                    match result {
                        Err(error) => {
                            // todo!(), why isn't this logged inside the span?
                            tracing::error!(%error, total_duration_ms, processing_duration_ms, queue_duration_ms, payload_type, "error handling message")
                        }
                        Ok(()) => tracing::info!(total_duration_ms, processing_duration_ms, queue_duration_ms, "finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
 779
 780    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 781    where
 782        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 783        Fut: 'static + Send + Future<Output = Result<()>>,
 784        M: EnvelopedMessage,
 785    {
 786        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 787        self
 788    }
 789
 790    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 791    where
 792        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 793        Fut: Send + Future<Output = Result<()>>,
 794        M: RequestMessage,
 795    {
 796        let handler = Arc::new(handler);
 797        self.add_handler(move |envelope, session| {
 798            let receipt = envelope.receipt();
 799            let handler = handler.clone();
 800            async move {
 801                let peer = session.peer.clone();
 802                let responded = Arc::new(AtomicBool::default());
 803                let response = Response {
 804                    peer: peer.clone(),
 805                    responded: responded.clone(),
 806                    receipt,
 807                };
 808                match (handler)(envelope.payload, response, session).await {
 809                    Ok(()) => {
 810                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 811                            Ok(())
 812                        } else {
 813                            Err(anyhow!("handler did not send a response"))?
 814                        }
 815                    }
 816                    Err(error) => {
 817                        let proto_err = match &error {
 818                            Error::Internal(err) => err.to_proto(),
 819                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 820                        };
 821                        peer.respond_with_error(receipt, proto_err)?;
 822                        Err(error)
 823                    }
 824                }
 825            }
 826        })
 827    }
 828
 829    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 830    where
 831        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 832        Fut: Send + Future<Output = Result<()>>,
 833        M: RequestMessage,
 834    {
 835        let handler = Arc::new(handler);
 836        self.add_handler(move |envelope, session| {
 837            let receipt = envelope.receipt();
 838            let handler = handler.clone();
 839            async move {
 840                let peer = session.peer.clone();
 841                let response = StreamingResponse {
 842                    peer: peer.clone(),
 843                    receipt,
 844                };
 845                match (handler)(envelope.payload, response, session).await {
 846                    Ok(()) => {
 847                        peer.end_stream(receipt)?;
 848                        Ok(())
 849                    }
 850                    Err(error) => {
 851                        let proto_err = match &error {
 852                            Error::Internal(err) => err.to_proto(),
 853                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 854                        };
 855                        peer.respond_with_error(receipt, proto_err)?;
 856                        Err(error)
 857                    }
 858                }
 859            }
 860        })
 861    }
 862
 863    #[allow(clippy::too_many_arguments)]
 864    pub fn handle_connection(
 865        self: &Arc<Self>,
 866        connection: Connection,
 867        address: String,
 868        principal: Principal,
 869        zed_version: ZedVersion,
 870        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 871        executor: Executor,
 872    ) -> impl Future<Output = ()> {
 873        let this = self.clone();
 874        let span = info_span!("handle connection", %address,
 875            connection_id=field::Empty,
 876            user_id=field::Empty,
 877            login=field::Empty,
 878            impersonator=field::Empty,
 879            dev_server_id=field::Empty
 880        );
 881        principal.update_span(&span);
 882
 883        let mut teardown = self.teardown.subscribe();
 884        async move {
 885            if *teardown.borrow() {
 886                tracing::error!("server is tearing down");
 887                return
 888            }
 889            let (connection_id, handle_io, mut incoming_rx) = this
 890                .peer
 891                .add_connection(connection, {
 892                    let executor = executor.clone();
 893                    move |duration| executor.sleep(duration)
 894                });
 895            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 896            tracing::info!("connection opened");
 897
 898            let http_client = match IsahcHttpClient::new() {
 899                Ok(http_client) => http_client,
 900                Err(error) => {
 901                    tracing::error!(?error, "failed to create HTTP client");
 902                    return;
 903                }
 904            };
 905
 906            let session = Session {
 907                principal: principal.clone(),
 908                connection_id,
 909                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 910                peer: this.peer.clone(),
 911                connection_pool: this.connection_pool.clone(),
 912                live_kit_client: this.app_state.live_kit_client.clone(),
 913                http_client,
 914                rate_limiter: this.app_state.rate_limiter.clone(),
 915                _executor: executor.clone(),
 916            };
 917
 918            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
 919                tracing::error!(?error, "failed to send initial client update");
 920                return;
 921            }
 922
 923            let handle_io = handle_io.fuse();
 924            futures::pin_mut!(handle_io);
 925
 926            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 927            // This prevents deadlocks when e.g., client A performs a request to client B and
 928            // client B performs a request to client A. If both clients stop processing further
 929            // messages until their respective request completes, they won't have a chance to
 930            // respond to the other client's request and cause a deadlock.
 931            //
 932            // This arrangement ensures we will attempt to process earlier messages first, but fall
 933            // back to processing messages arrived later in the spirit of making progress.
 934            let mut foreground_message_handlers = FuturesUnordered::new();
 935            let concurrent_handlers = Arc::new(Semaphore::new(256));
 936            loop {
 937                let next_message = async {
 938                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 939                    let message = incoming_rx.next().await;
 940                    (permit, message)
 941                }.fuse();
 942                futures::pin_mut!(next_message);
 943                futures::select_biased! {
 944                    _ = teardown.changed().fuse() => return,
 945                    result = handle_io => {
 946                        if let Err(error) = result {
 947                            tracing::error!(?error, "error handling I/O");
 948                        }
 949                        break;
 950                    }
 951                    _ = foreground_message_handlers.next() => {}
 952                    next_message = next_message => {
 953                        let (permit, message) = next_message;
 954                        if let Some(message) = message {
 955                            let type_name = message.payload_type_name();
 956                            // note: we copy all the fields from the parent span so we can query them in the logs.
 957                            // (https://github.com/tokio-rs/tracing/issues/2670).
 958                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
 959                                user_id=field::Empty,
 960                                login=field::Empty,
 961                                impersonator=field::Empty,
 962                                dev_server_id=field::Empty
 963                            );
 964                            principal.update_span(&span);
 965                            let span_enter = span.enter();
 966                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 967                                let is_background = message.is_background();
 968                                let handle_message = (handler)(message, session.clone());
 969                                drop(span_enter);
 970
 971                                let handle_message = async move {
 972                                    handle_message.await;
 973                                    drop(permit);
 974                                }.instrument(span);
 975                                if is_background {
 976                                    executor.spawn_detached(handle_message);
 977                                } else {
 978                                    foreground_message_handlers.push(handle_message);
 979                                }
 980                            } else {
 981                                tracing::error!("no message handler");
 982                            }
 983                        } else {
 984                            tracing::info!("connection closed");
 985                            break;
 986                        }
 987                    }
 988                }
 989            }
 990
 991            drop(foreground_message_handlers);
 992            tracing::info!("signing out");
 993            if let Err(error) = connection_lost(session, teardown, executor).await {
 994                tracing::error!(?error, "error signing out");
 995            }
 996
 997        }.instrument(span)
 998    }
 999
    /// Performs the initial handshake for a freshly added connection.
    ///
    /// It first sends `Hello` (carrying the peer id) and hands `connection_id`
    /// to `send_connection_id`, if one was supplied. It then branches on the
    /// principal.
    ///
    /// Users (including impersonated ones):
    /// - on their first ever connection, are sent `ShowContacts` and marked as
    ///   connected once;
    /// - have the connection added to the pool and subscribed to every channel
    ///   they are a member of;
    /// - are sent the initial contacts, user-channel and channel updates;
    /// - are forwarded any pending incoming call;
    /// - finally, `update_user_contacts` is run.
    ///
    /// Dev servers:
    /// - fail with `ErrorCode::DevServerAlreadyOnline` if the pool already has a
    ///   connection for this dev server;
    /// - otherwise are registered in the pool, marked online, and sent their
    ///   remote projects via `DevServerInstructions`.
    ///
    /// # Errors
    ///
    /// Returns the first send or database error encountered. Side effects that
    /// already happened (messages sent, pool registration) are not rolled back
    /// by this function.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        principal: &Principal,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // The send result is ignored: a dropped receiver is not treated as an error.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                // The three lookups run concurrently; the first error aborts all of them.
                let (contacts, channels_for_user, channel_invites) = future::try_join3(
                    self.app_state.db.get_contacts(user.id),
                    self.app_state.db.get_channels_for_user(user.id),
                    self.app_state.db.get_channel_invites_for_user(user.id),
                )
                .await?;

                // Register the connection and build the pool-dependent updates under
                // a single lock, so the updates see the pool including this connection.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    for membership in &channels_for_user.channel_memberships {
                        pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
                    }
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                    self.peer.send(
                        connection_id,
                        build_update_user_channels(&channels_for_user),
                    )?;
                    self.peer.send(
                        connection_id,
                        build_channels_update(channels_for_user, channel_invites, &pool),
                    )?;
                }

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, &session).await?;
            }
            Principal::DevServer(dev_server) => {
                // Check and registration happen under one lock, so two connections
                // for the same dev server cannot both pass the check.
                {
                    let mut pool = self.connection_pool.lock();
                    if pool.dev_server_connection_id(dev_server.id).is_some() {
                        return Err(anyhow!(ErrorCode::DevServerAlreadyOnline))?;
                    };
                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
                }
                update_dev_server_status(dev_server, proto::DevServerStatus::Online, &session)
                    .await;
                // todo!() allow only one connection.
                // NOTE(review): the check above already rejects a second connection
                // for the same dev server; this TODO may be stale — confirm.

                let projects = self
                    .app_state
                    .db
                    .get_remote_projects_for_dev_server(dev_server.id)
                    .await?;
                self.peer
                    .send(connection_id, proto::DevServerInstructions { projects })?;
            }
        }

        Ok(())
    }
1088
1089    pub async fn invite_code_redeemed(
1090        self: &Arc<Self>,
1091        inviter_id: UserId,
1092        invitee_id: UserId,
1093    ) -> Result<()> {
1094        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1095            if let Some(code) = &user.invite_code {
1096                let pool = self.connection_pool.lock();
1097                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1098                for connection_id in pool.user_connection_ids(inviter_id) {
1099                    self.peer.send(
1100                        connection_id,
1101                        proto::UpdateContacts {
1102                            contacts: vec![invitee_contact.clone()],
1103                            ..Default::default()
1104                        },
1105                    )?;
1106                    self.peer.send(
1107                        connection_id,
1108                        proto::UpdateInviteInfo {
1109                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1110                            count: user.invite_count as u32,
1111                        },
1112                    )?;
1113                }
1114            }
1115        }
1116        Ok(())
1117    }
1118
1119    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1120        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1121            if let Some(invite_code) = &user.invite_code {
1122                let pool = self.connection_pool.lock();
1123                for connection_id in pool.user_connection_ids(user_id) {
1124                    self.peer.send(
1125                        connection_id,
1126                        proto::UpdateInviteInfo {
1127                            url: format!(
1128                                "{}{}",
1129                                self.app_state.config.invite_link_prefix, invite_code
1130                            ),
1131                            count: user.invite_count as u32,
1132                        },
1133                    )?;
1134                }
1135            }
1136        }
1137        Ok(())
1138    }
1139
1140    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1141        ServerSnapshot {
1142            connection_pool: ConnectionPoolGuard {
1143                guard: self.connection_pool.lock(),
1144                _not_send: PhantomData,
1145            },
1146            peer: &self.peer,
1147        }
1148    }
1149}
1150
1151impl<'a> Deref for ConnectionPoolGuard<'a> {
1152    type Target = ConnectionPool;
1153
1154    fn deref(&self) -> &Self::Target {
1155        &self.guard
1156    }
1157}
1158
1159impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1160    fn deref_mut(&mut self) -> &mut Self::Target {
1161        &mut self.guard
1162    }
1163}
1164
/// In test builds, dropping the guard calls `check_invariants` on the pool
/// (presumably asserting its internal consistency — see that method). In other
/// builds the body is empty; the inner lock guard is then released as usual
/// when the field is dropped.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // Compiled only under `cfg(test)`.
        #[cfg(test)]
        self.check_invariants();
    }
}
1171
1172fn broadcast<F>(
1173    sender_id: Option<ConnectionId>,
1174    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1175    mut f: F,
1176) where
1177    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1178{
1179    for receiver_id in receiver_ids {
1180        if Some(receiver_id) != sender_id {
1181            if let Err(error) = f(receiver_id) {
1182                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1183            }
1184        }
1185    }
1186}
1187
1188pub struct ProtocolVersion(u32);
1189
1190impl Header for ProtocolVersion {
1191    fn name() -> &'static HeaderName {
1192        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1193        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1194    }
1195
1196    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1197    where
1198        Self: Sized,
1199        I: Iterator<Item = &'i axum::http::HeaderValue>,
1200    {
1201        let version = values
1202            .next()
1203            .ok_or_else(axum::headers::Error::invalid)?
1204            .to_str()
1205            .map_err(|_| axum::headers::Error::invalid())?
1206            .parse()
1207            .map_err(|_| axum::headers::Error::invalid())?;
1208        Ok(Self(version))
1209    }
1210
1211    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1212        values.extend([self.0.to_string().parse().unwrap()]);
1213    }
1214}
1215
1216pub struct AppVersionHeader(SemanticVersion);
1217impl Header for AppVersionHeader {
1218    fn name() -> &'static HeaderName {
1219        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1220        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1221    }
1222
1223    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1224    where
1225        Self: Sized,
1226        I: Iterator<Item = &'i axum::http::HeaderValue>,
1227    {
1228        let version = values
1229            .next()
1230            .ok_or_else(axum::headers::Error::invalid)?
1231            .to_str()
1232            .map_err(|_| axum::headers::Error::invalid())?
1233            .parse()
1234            .map_err(|_| axum::headers::Error::invalid())?;
1235        Ok(Self(version))
1236    }
1237
1238    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1239        values.extend([self.0.to_string().parse().unwrap()]);
1240    }
1241}
1242
/// Builds the router for the RPC endpoints.
///
/// - `GET /rpc` (WebSocket upgrade) is wrapped by a layer stack. That stack
///   provides the shared `AppState` as an extension and runs
///   `auth::validate_header`.
/// - `GET /metrics` is registered after that layer and so is not wrapped by it.
/// - The final `Extension(server)` layer wraps both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // `Router::layer` only wraps routes registered before the call, so this
        // stack applies to `/rpc` alone. Order matters here.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        // Added after the auth layer: this router does not authenticate `/metrics`.
        .route("/metrics", get(handle_metrics))
        // Wraps both routes; their handlers extract `Arc<Server>` from it.
        .layer(Extension(server))
}
1254
1255pub async fn handle_websocket_request(
1256    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1257    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1258    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1259    Extension(server): Extension<Arc<Server>>,
1260    Extension(principal): Extension<Principal>,
1261    ws: WebSocketUpgrade,
1262) -> axum::response::Response {
1263    if protocol_version != rpc::PROTOCOL_VERSION {
1264        return (
1265            StatusCode::UPGRADE_REQUIRED,
1266            "client must be upgraded".to_string(),
1267        )
1268            .into_response();
1269    }
1270
1271    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1272        return (
1273            StatusCode::UPGRADE_REQUIRED,
1274            "no version header found".to_string(),
1275        )
1276            .into_response();
1277    };
1278
1279    if !version.can_collaborate() {
1280        return (
1281            StatusCode::UPGRADE_REQUIRED,
1282            "client must be upgraded".to_string(),
1283        )
1284            .into_response();
1285    }
1286
1287    let socket_address = socket_address.to_string();
1288    ws.on_upgrade(move |socket| {
1289        let socket = socket
1290            .map_ok(to_tungstenite_message)
1291            .err_into()
1292            .with(|message| async move { Ok(to_axum_message(message)) });
1293        let connection = Connection::new(Box::pin(socket));
1294        async move {
1295            server
1296                .handle_connection(
1297                    connection,
1298                    socket_address,
1299                    principal,
1300                    version,
1301                    None,
1302                    Executor::Production,
1303                )
1304                .await;
1305        }
1306    })
1307}
1308
1309pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1310    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1311    let connections_metric = CONNECTIONS_METRIC
1312        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1313
1314    let connections = server
1315        .connection_pool
1316        .lock()
1317        .connections()
1318        .filter(|connection| !connection.admin)
1319        .count();
1320    connections_metric.set(connections as _);
1321
1322    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1323    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1324        register_int_gauge!(
1325            "shared_projects",
1326            "number of open projects with one or more guests"
1327        )
1328        .unwrap()
1329    });
1330
1331    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1332    shared_projects_metric.set(shared_projects as _);
1333
1334    let encoder = prometheus::TextEncoder::new();
1335    let metric_families = prometheus::gather();
1336    let encoded_metrics = encoder
1337        .encode_to_string(&metric_families)
1338        .map_err(|err| anyhow!("{}", err))?;
1339    Ok(encoded_metrics)
1340}
1341
/// Cleans up after a connection's message loop has ended.
///
/// Runs in two phases.
///
/// Immediately:
/// - disconnects the peer;
/// - removes the connection from the pool (returning the error if that fails);
/// - tells the database the connection was lost (errors there are only logged).
///
/// Deferred: it then waits until either `RECONNECT_TIMEOUT` elapses or the
/// `teardown` channel changes. The timeout branch is listed first in
/// `select_biased!`, so it wins when both are ready. Presumably the delay gives
/// a reconnecting client the chance to reclaim its state — confirm.
///
/// When the timeout elapses:
/// - For users: leaves the room and channel buffers (errors logged). If the
///   user has no other online connection, declines any pending call and
///   broadcasts the room update. Then updates the user's contacts.
/// - For dev servers: runs `lost_dev_server_connection` and marks the dev
///   server offline.
///
/// If teardown fires first, no deferred cleanup is performed.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // Assumed infallible because the principal matched a user variant.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a pending call when no other connection of this
                    // user is still online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
            Principal::DevServer(dev_server) => {
                lost_dev_server_connection(&session).await?;
                update_dev_server_status(&dev_server, proto::DevServerStatus::Offline, &session)
                    .await;
            },
        }
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1398
1399/// Acknowledges a ping from a client, used to keep the connection alive.
1400async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1401    response.send(proto::Ack {})?;
1402    Ok(())
1403}
1404
1405/// Creates a new room for calling (outside of channels)
1406async fn create_room(
1407    _request: proto::CreateRoom,
1408    response: Response<proto::CreateRoom>,
1409    session: UserSession,
1410) -> Result<()> {
1411    let live_kit_room = nanoid::nanoid!(30);
1412
1413    let live_kit_connection_info = util::maybe!(async {
1414        let live_kit = session.live_kit_client.as_ref();
1415        let live_kit = live_kit?;
1416        let user_id = session.user_id().to_string();
1417
1418        let token = live_kit
1419            .room_token(&live_kit_room, &user_id.to_string())
1420            .trace_err()?;
1421
1422        Some(proto::LiveKitConnectionInfo {
1423            server_url: live_kit.url().into(),
1424            token,
1425            can_publish: true,
1426        })
1427    })
1428    .await;
1429
1430    let room = session
1431        .db()
1432        .await
1433        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1434        .await?;
1435
1436    response.send(proto::CreateRoomResponse {
1437        room: Some(room.clone()),
1438        live_kit_connection_info,
1439    })?;
1440
1441    update_user_contacts(session.user_id(), &session).await?;
1442    Ok(())
1443}
1444
1445/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1446async fn join_room(
1447    request: proto::JoinRoom,
1448    response: Response<proto::JoinRoom>,
1449    session: UserSession,
1450) -> Result<()> {
1451    let room_id = RoomId::from_proto(request.id);
1452
1453    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1454
1455    if let Some(channel_id) = channel_id {
1456        return join_channel_internal(channel_id, Box::new(response), session).await;
1457    }
1458
1459    let joined_room = {
1460        let room = session
1461            .db()
1462            .await
1463            .join_room(room_id, session.user_id(), session.connection_id)
1464            .await?;
1465        room_updated(&room.room, &session.peer);
1466        room.into_inner()
1467    };
1468
1469    for connection_id in session
1470        .connection_pool()
1471        .await
1472        .user_connection_ids(session.user_id())
1473    {
1474        session
1475            .peer
1476            .send(
1477                connection_id,
1478                proto::CallCanceled {
1479                    room_id: room_id.to_proto(),
1480                },
1481            )
1482            .trace_err();
1483    }
1484
1485    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1486        if let Some(token) = live_kit
1487            .room_token(
1488                &joined_room.room.live_kit_room,
1489                &session.user_id().to_string(),
1490            )
1491            .trace_err()
1492        {
1493            Some(proto::LiveKitConnectionInfo {
1494                server_url: live_kit.url().into(),
1495                token,
1496                can_publish: true,
1497            })
1498        } else {
1499            None
1500        }
1501    } else {
1502        None
1503    };
1504
1505    response.send(proto::JoinRoomResponse {
1506        room: Some(joined_room.room),
1507        channel_id: None,
1508        live_kit_connection_info,
1509    })?;
1510
1511    update_user_contacts(session.user_id(), &session).await?;
1512    Ok(())
1513}
1514
1515/// Rejoin room is used to reconnect to a room after connection errors.
1516async fn rejoin_room(
1517    request: proto::RejoinRoom,
1518    response: Response<proto::RejoinRoom>,
1519    session: UserSession,
1520) -> Result<()> {
1521    let room;
1522    let channel;
1523    {
1524        let mut rejoined_room = session
1525            .db()
1526            .await
1527            .rejoin_room(request, session.user_id(), session.connection_id)
1528            .await?;
1529
1530        response.send(proto::RejoinRoomResponse {
1531            room: Some(rejoined_room.room.clone()),
1532            reshared_projects: rejoined_room
1533                .reshared_projects
1534                .iter()
1535                .map(|project| proto::ResharedProject {
1536                    id: project.id.to_proto(),
1537                    collaborators: project
1538                        .collaborators
1539                        .iter()
1540                        .map(|collaborator| collaborator.to_proto())
1541                        .collect(),
1542                })
1543                .collect(),
1544            rejoined_projects: rejoined_room
1545                .rejoined_projects
1546                .iter()
1547                .map(|rejoined_project| rejoined_project.to_proto())
1548                .collect(),
1549        })?;
1550        room_updated(&rejoined_room.room, &session.peer);
1551
1552        for project in &rejoined_room.reshared_projects {
1553            for collaborator in &project.collaborators {
1554                session
1555                    .peer
1556                    .send(
1557                        collaborator.connection_id,
1558                        proto::UpdateProjectCollaborator {
1559                            project_id: project.id.to_proto(),
1560                            old_peer_id: Some(project.old_connection_id.into()),
1561                            new_peer_id: Some(session.connection_id.into()),
1562                        },
1563                    )
1564                    .trace_err();
1565            }
1566
1567            broadcast(
1568                Some(session.connection_id),
1569                project
1570                    .collaborators
1571                    .iter()
1572                    .map(|collaborator| collaborator.connection_id),
1573                |connection_id| {
1574                    session.peer.forward_send(
1575                        session.connection_id,
1576                        connection_id,
1577                        proto::UpdateProject {
1578                            project_id: project.id.to_proto(),
1579                            worktrees: project.worktrees.clone(),
1580                        },
1581                    )
1582                },
1583            );
1584        }
1585
1586        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1587
1588        let rejoined_room = rejoined_room.into_inner();
1589
1590        room = rejoined_room.room;
1591        channel = rejoined_room.channel;
1592    }
1593
1594    if let Some(channel) = channel {
1595        channel_updated(
1596            &channel,
1597            &room,
1598            &session.peer,
1599            &*session.connection_pool().await,
1600        );
1601    }
1602
1603    update_user_contacts(session.user_id(), &session).await?;
1604    Ok(())
1605}
1606
1607fn notify_rejoined_projects(
1608    rejoined_projects: &mut Vec<RejoinedProject>,
1609    session: &UserSession,
1610) -> Result<()> {
1611    for project in rejoined_projects.iter() {
1612        for collaborator in &project.collaborators {
1613            session
1614                .peer
1615                .send(
1616                    collaborator.connection_id,
1617                    proto::UpdateProjectCollaborator {
1618                        project_id: project.id.to_proto(),
1619                        old_peer_id: Some(project.old_connection_id.into()),
1620                        new_peer_id: Some(session.connection_id.into()),
1621                    },
1622                )
1623                .trace_err();
1624        }
1625    }
1626
1627    for project in rejoined_projects {
1628        for worktree in mem::take(&mut project.worktrees) {
1629            #[cfg(any(test, feature = "test-support"))]
1630            const MAX_CHUNK_SIZE: usize = 2;
1631            #[cfg(not(any(test, feature = "test-support")))]
1632            const MAX_CHUNK_SIZE: usize = 256;
1633
1634            // Stream this worktree's entries.
1635            let message = proto::UpdateWorktree {
1636                project_id: project.id.to_proto(),
1637                worktree_id: worktree.id,
1638                abs_path: worktree.abs_path.clone(),
1639                root_name: worktree.root_name,
1640                updated_entries: worktree.updated_entries,
1641                removed_entries: worktree.removed_entries,
1642                scan_id: worktree.scan_id,
1643                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1644                updated_repositories: worktree.updated_repositories,
1645                removed_repositories: worktree.removed_repositories,
1646            };
1647            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1648                session.peer.send(session.connection_id, update.clone())?;
1649            }
1650
1651            // Stream this worktree's diagnostics.
1652            for summary in worktree.diagnostic_summaries {
1653                session.peer.send(
1654                    session.connection_id,
1655                    proto::UpdateDiagnosticSummary {
1656                        project_id: project.id.to_proto(),
1657                        worktree_id: worktree.id,
1658                        summary: Some(summary),
1659                    },
1660                )?;
1661            }
1662
1663            for settings_file in worktree.settings_files {
1664                session.peer.send(
1665                    session.connection_id,
1666                    proto::UpdateWorktreeSettings {
1667                        project_id: project.id.to_proto(),
1668                        worktree_id: worktree.id,
1669                        path: settings_file.path,
1670                        content: Some(settings_file.content),
1671                    },
1672                )?;
1673            }
1674        }
1675
1676        for language_server in &project.language_servers {
1677            session.peer.send(
1678                session.connection_id,
1679                proto::UpdateLanguageServer {
1680                    project_id: project.id.to_proto(),
1681                    language_server_id: language_server.id,
1682                    variant: Some(
1683                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1684                            proto::LspDiskBasedDiagnosticsUpdated {},
1685                        ),
1686                    ),
1687                },
1688            )?;
1689        }
1690    }
1691    Ok(())
1692}
1693
1694/// leave room disconnects from the room.
1695async fn leave_room(
1696    _: proto::LeaveRoom,
1697    response: Response<proto::LeaveRoom>,
1698    session: UserSession,
1699) -> Result<()> {
1700    leave_room_for_session(&session, session.connection_id).await?;
1701    response.send(proto::Ack {})?;
1702    Ok(())
1703}
1704
1705/// Updates the permissions of someone else in the room.
1706async fn set_room_participant_role(
1707    request: proto::SetRoomParticipantRole,
1708    response: Response<proto::SetRoomParticipantRole>,
1709    session: UserSession,
1710) -> Result<()> {
1711    let user_id = UserId::from_proto(request.user_id);
1712    let role = ChannelRole::from(request.role());
1713
1714    let (live_kit_room, can_publish) = {
1715        let room = session
1716            .db()
1717            .await
1718            .set_room_participant_role(
1719                session.user_id(),
1720                RoomId::from_proto(request.room_id),
1721                user_id,
1722                role,
1723            )
1724            .await?;
1725
1726        let live_kit_room = room.live_kit_room.clone();
1727        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1728        room_updated(&room, &session.peer);
1729        (live_kit_room, can_publish)
1730    };
1731
1732    if let Some(live_kit) = session.live_kit_client.as_ref() {
1733        live_kit
1734            .update_participant(
1735                live_kit_room.clone(),
1736                request.user_id.to_string(),
1737                live_kit_server::proto::ParticipantPermission {
1738                    can_subscribe: true,
1739                    can_publish,
1740                    can_publish_data: can_publish,
1741                    hidden: false,
1742                    recorder: false,
1743                },
1744            )
1745            .await
1746            .trace_err();
1747    }
1748
1749    response.send(proto::Ack {})?;
1750    Ok(())
1751}
1752
1753/// Call someone else into the current room
1754async fn call(
1755    request: proto::Call,
1756    response: Response<proto::Call>,
1757    session: UserSession,
1758) -> Result<()> {
1759    let room_id = RoomId::from_proto(request.room_id);
1760    let calling_user_id = session.user_id();
1761    let calling_connection_id = session.connection_id;
1762    let called_user_id = UserId::from_proto(request.called_user_id);
1763    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1764    if !session
1765        .db()
1766        .await
1767        .has_contact(calling_user_id, called_user_id)
1768        .await?
1769    {
1770        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1771    }
1772
1773    let incoming_call = {
1774        let (room, incoming_call) = &mut *session
1775            .db()
1776            .await
1777            .call(
1778                room_id,
1779                calling_user_id,
1780                calling_connection_id,
1781                called_user_id,
1782                initial_project_id,
1783            )
1784            .await?;
1785        room_updated(&room, &session.peer);
1786        mem::take(incoming_call)
1787    };
1788    update_user_contacts(called_user_id, &session).await?;
1789
1790    let mut calls = session
1791        .connection_pool()
1792        .await
1793        .user_connection_ids(called_user_id)
1794        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1795        .collect::<FuturesUnordered<_>>();
1796
1797    while let Some(call_response) = calls.next().await {
1798        match call_response.as_ref() {
1799            Ok(_) => {
1800                response.send(proto::Ack {})?;
1801                return Ok(());
1802            }
1803            Err(_) => {
1804                call_response.trace_err();
1805            }
1806        }
1807    }
1808
1809    {
1810        let room = session
1811            .db()
1812            .await
1813            .call_failed(room_id, called_user_id)
1814            .await?;
1815        room_updated(&room, &session.peer);
1816    }
1817    update_user_contacts(called_user_id, &session).await?;
1818
1819    Err(anyhow!("failed to ring user"))?
1820}
1821
1822/// Cancel an outgoing call.
1823async fn cancel_call(
1824    request: proto::CancelCall,
1825    response: Response<proto::CancelCall>,
1826    session: UserSession,
1827) -> Result<()> {
1828    let called_user_id = UserId::from_proto(request.called_user_id);
1829    let room_id = RoomId::from_proto(request.room_id);
1830    {
1831        let room = session
1832            .db()
1833            .await
1834            .cancel_call(room_id, session.connection_id, called_user_id)
1835            .await?;
1836        room_updated(&room, &session.peer);
1837    }
1838
1839    for connection_id in session
1840        .connection_pool()
1841        .await
1842        .user_connection_ids(called_user_id)
1843    {
1844        session
1845            .peer
1846            .send(
1847                connection_id,
1848                proto::CallCanceled {
1849                    room_id: room_id.to_proto(),
1850                },
1851            )
1852            .trace_err();
1853    }
1854    response.send(proto::Ack {})?;
1855
1856    update_user_contacts(called_user_id, &session).await?;
1857    Ok(())
1858}
1859
1860/// Decline an incoming call.
1861async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1862    let room_id = RoomId::from_proto(message.room_id);
1863    {
1864        let room = session
1865            .db()
1866            .await
1867            .decline_call(Some(room_id), session.user_id())
1868            .await?
1869            .ok_or_else(|| anyhow!("failed to decline call"))?;
1870        room_updated(&room, &session.peer);
1871    }
1872
1873    for connection_id in session
1874        .connection_pool()
1875        .await
1876        .user_connection_ids(session.user_id())
1877    {
1878        session
1879            .peer
1880            .send(
1881                connection_id,
1882                proto::CallCanceled {
1883                    room_id: room_id.to_proto(),
1884                },
1885            )
1886            .trace_err();
1887    }
1888    update_user_contacts(session.user_id(), &session).await?;
1889    Ok(())
1890}
1891
1892/// Updates other participants in the room with your current location.
1893async fn update_participant_location(
1894    request: proto::UpdateParticipantLocation,
1895    response: Response<proto::UpdateParticipantLocation>,
1896    session: UserSession,
1897) -> Result<()> {
1898    let room_id = RoomId::from_proto(request.room_id);
1899    let location = request
1900        .location
1901        .ok_or_else(|| anyhow!("invalid location"))?;
1902
1903    let db = session.db().await;
1904    let room = db
1905        .update_room_participant_location(room_id, session.connection_id, location)
1906        .await?;
1907
1908    room_updated(&room, &session.peer);
1909    response.send(proto::Ack {})?;
1910    Ok(())
1911}
1912
1913/// Share a project into the room.
1914async fn share_project(
1915    request: proto::ShareProject,
1916    response: Response<proto::ShareProject>,
1917    session: UserSession,
1918) -> Result<()> {
1919    let (project_id, room) = &*session
1920        .db()
1921        .await
1922        .share_project(
1923            RoomId::from_proto(request.room_id),
1924            session.connection_id,
1925            &request.worktrees,
1926        )
1927        .await?;
1928    response.send(proto::ShareProjectResponse {
1929        project_id: project_id.to_proto(),
1930    })?;
1931    room_updated(&room, &session.peer);
1932
1933    Ok(())
1934}
1935
1936/// Unshare a project from the room.
1937async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1938    let project_id = ProjectId::from_proto(message.project_id);
1939    unshare_project_internal(project_id, &session).await
1940}
1941
1942async fn unshare_project_internal(project_id: ProjectId, session: &Session) -> Result<()> {
1943    let (room, guest_connection_ids) = &*session
1944        .db()
1945        .await
1946        .unshare_project(project_id, session.connection_id)
1947        .await?;
1948
1949    let message = proto::UnshareProject {
1950        project_id: project_id.to_proto(),
1951    };
1952
1953    broadcast(
1954        Some(session.connection_id),
1955        guest_connection_ids.iter().copied(),
1956        |conn_id| session.peer.send(conn_id, message.clone()),
1957    );
1958    if let Some(room) = room {
1959        room_updated(room, &session.peer);
1960    }
1961
1962    Ok(())
1963}
1964
1965/// Share a project into the room.
1966async fn share_remote_project(
1967    request: proto::ShareRemoteProject,
1968    response: Response<proto::ShareRemoteProject>,
1969    session: DevServerSession,
1970) -> Result<()> {
1971    let remote_project = session
1972        .db()
1973        .await
1974        .share_remote_project(
1975            RemoteProjectId::from_proto(request.remote_project_id),
1976            session.dev_server_id(),
1977            session.connection_id,
1978            &request.worktrees,
1979        )
1980        .await?;
1981    let Some(project_id) = remote_project.project_id else {
1982        return Err(anyhow!("failed to share remote project"))?;
1983    };
1984
1985    for (connection_id, _) in session
1986        .connection_pool()
1987        .await
1988        .channel_connection_ids(ChannelId::from_proto(remote_project.channel_id))
1989    {
1990        session
1991            .peer
1992            .send(
1993                connection_id,
1994                proto::UpdateChannels {
1995                    remote_projects: vec![remote_project.clone()],
1996                    ..Default::default()
1997                },
1998            )
1999            .trace_err();
2000    }
2001
2002    response.send(proto::ShareProjectResponse { project_id })?;
2003
2004    Ok(())
2005}
2006
2007/// Join someone elses shared project.
2008async fn join_project(
2009    request: proto::JoinProject,
2010    response: Response<proto::JoinProject>,
2011    session: UserSession,
2012) -> Result<()> {
2013    let project_id = ProjectId::from_proto(request.project_id);
2014
2015    tracing::info!(%project_id, "join project");
2016
2017    let db = session.db().await;
2018    let (project, replica_id) = &mut *db
2019        .join_project(project_id, session.connection_id, session.user_id())
2020        .await?;
2021    drop(db);
2022    tracing::info!(%project_id, "join remote project");
2023    join_project_internal(response, session, project, replica_id)
2024}
2025
2026trait JoinProjectInternalResponse {
2027    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
2028}
2029impl JoinProjectInternalResponse for Response<proto::JoinProject> {
2030    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2031        Response::<proto::JoinProject>::send(self, result)
2032    }
2033}
2034impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
2035    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2036        Response::<proto::JoinHostedProject>::send(self, result)
2037    }
2038}
2039
2040fn join_project_internal(
2041    response: impl JoinProjectInternalResponse,
2042    session: UserSession,
2043    project: &mut Project,
2044    replica_id: &ReplicaId,
2045) -> Result<()> {
2046    let collaborators = project
2047        .collaborators
2048        .iter()
2049        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2050        .map(|collaborator| collaborator.to_proto())
2051        .collect::<Vec<_>>();
2052    let project_id = project.id;
2053    let guest_user_id = session.user_id();
2054
2055    let worktrees = project
2056        .worktrees
2057        .iter()
2058        .map(|(id, worktree)| proto::WorktreeMetadata {
2059            id: *id,
2060            root_name: worktree.root_name.clone(),
2061            visible: worktree.visible,
2062            abs_path: worktree.abs_path.clone(),
2063        })
2064        .collect::<Vec<_>>();
2065
2066    for collaborator in &collaborators {
2067        session
2068            .peer
2069            .send(
2070                collaborator.peer_id.unwrap().into(),
2071                proto::AddProjectCollaborator {
2072                    project_id: project_id.to_proto(),
2073                    collaborator: Some(proto::Collaborator {
2074                        peer_id: Some(session.connection_id.into()),
2075                        replica_id: replica_id.0 as u32,
2076                        user_id: guest_user_id.to_proto(),
2077                    }),
2078                },
2079            )
2080            .trace_err();
2081    }
2082
2083    // First, we send the metadata associated with each worktree.
2084    response.send(proto::JoinProjectResponse {
2085        project_id: project.id.0 as u64,
2086        worktrees: worktrees.clone(),
2087        replica_id: replica_id.0 as u32,
2088        collaborators: collaborators.clone(),
2089        language_servers: project.language_servers.clone(),
2090        role: project.role.into(), // todo
2091    })?;
2092
2093    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2094        #[cfg(any(test, feature = "test-support"))]
2095        const MAX_CHUNK_SIZE: usize = 2;
2096        #[cfg(not(any(test, feature = "test-support")))]
2097        const MAX_CHUNK_SIZE: usize = 256;
2098
2099        // Stream this worktree's entries.
2100        let message = proto::UpdateWorktree {
2101            project_id: project_id.to_proto(),
2102            worktree_id,
2103            abs_path: worktree.abs_path.clone(),
2104            root_name: worktree.root_name,
2105            updated_entries: worktree.entries,
2106            removed_entries: Default::default(),
2107            scan_id: worktree.scan_id,
2108            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2109            updated_repositories: worktree.repository_entries.into_values().collect(),
2110            removed_repositories: Default::default(),
2111        };
2112        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2113            session.peer.send(session.connection_id, update.clone())?;
2114        }
2115
2116        // Stream this worktree's diagnostics.
2117        for summary in worktree.diagnostic_summaries {
2118            session.peer.send(
2119                session.connection_id,
2120                proto::UpdateDiagnosticSummary {
2121                    project_id: project_id.to_proto(),
2122                    worktree_id: worktree.id,
2123                    summary: Some(summary),
2124                },
2125            )?;
2126        }
2127
2128        for settings_file in worktree.settings_files {
2129            session.peer.send(
2130                session.connection_id,
2131                proto::UpdateWorktreeSettings {
2132                    project_id: project_id.to_proto(),
2133                    worktree_id: worktree.id,
2134                    path: settings_file.path,
2135                    content: Some(settings_file.content),
2136                },
2137            )?;
2138        }
2139    }
2140
2141    for language_server in &project.language_servers {
2142        session.peer.send(
2143            session.connection_id,
2144            proto::UpdateLanguageServer {
2145                project_id: project_id.to_proto(),
2146                language_server_id: language_server.id,
2147                variant: Some(
2148                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2149                        proto::LspDiskBasedDiagnosticsUpdated {},
2150                    ),
2151                ),
2152            },
2153        )?;
2154    }
2155
2156    Ok(())
2157}
2158
2159/// Leave someone elses shared project.
2160async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2161    let sender_id = session.connection_id;
2162    let project_id = ProjectId::from_proto(request.project_id);
2163    let db = session.db().await;
2164    if db.is_hosted_project(project_id).await? {
2165        let project = db.leave_hosted_project(project_id, sender_id).await?;
2166        project_left(&project, &session);
2167        return Ok(());
2168    }
2169
2170    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2171    tracing::info!(
2172        %project_id,
2173        host_user_id = ?project.host_user_id,
2174        host_connection_id = ?project.host_connection_id,
2175        "leave project"
2176    );
2177
2178    project_left(&project, &session);
2179    if let Some(room) = room {
2180        room_updated(&room, &session.peer);
2181    }
2182
2183    Ok(())
2184}
2185
2186async fn join_hosted_project(
2187    request: proto::JoinHostedProject,
2188    response: Response<proto::JoinHostedProject>,
2189    session: UserSession,
2190) -> Result<()> {
2191    let (mut project, replica_id) = session
2192        .db()
2193        .await
2194        .join_hosted_project(
2195            ProjectId(request.project_id as i32),
2196            session.user_id(),
2197            session.connection_id,
2198        )
2199        .await?;
2200
2201    join_project_internal(response, session, &mut project, &replica_id)
2202}
2203
2204async fn create_remote_project(
2205    request: proto::CreateRemoteProject,
2206    response: Response<proto::CreateRemoteProject>,
2207    session: UserSession,
2208) -> Result<()> {
2209    let (channel, remote_project) = session
2210        .db()
2211        .await
2212        .create_remote_project(
2213            ChannelId(request.channel_id as i32),
2214            DevServerId(request.dev_server_id as i32),
2215            &request.name,
2216            &request.path,
2217            session.user_id(),
2218        )
2219        .await?;
2220
2221    let projects = session
2222        .db()
2223        .await
2224        .get_remote_projects_for_dev_server(remote_project.dev_server_id)
2225        .await?;
2226
2227    let update = proto::UpdateChannels {
2228        remote_projects: vec![remote_project.to_proto(None)],
2229        ..Default::default()
2230    };
2231    let connection_pool = session.connection_pool().await;
2232    for (connection_id, role) in connection_pool.channel_connection_ids(channel.root_id()) {
2233        if role.can_see_all_descendants() {
2234            session.peer.send(connection_id, update.clone())?;
2235        }
2236    }
2237
2238    let dev_server_id = remote_project.dev_server_id;
2239    let dev_server_connection_id = connection_pool.dev_server_connection_id(dev_server_id);
2240    if let Some(dev_server_connection_id) = dev_server_connection_id {
2241        session.peer.send(
2242            dev_server_connection_id,
2243            proto::DevServerInstructions { projects },
2244        )?;
2245    }
2246
2247    response.send(proto::CreateRemoteProjectResponse {
2248        remote_project: Some(remote_project.to_proto(None)),
2249    })?;
2250    Ok(())
2251}
2252
2253async fn create_dev_server(
2254    request: proto::CreateDevServer,
2255    response: Response<proto::CreateDevServer>,
2256    session: UserSession,
2257) -> Result<()> {
2258    let access_token = auth::random_token();
2259    let hashed_access_token = auth::hash_access_token(&access_token);
2260
2261    let (channel, dev_server) = session
2262        .db()
2263        .await
2264        .create_dev_server(
2265            ChannelId(request.channel_id as i32),
2266            &request.name,
2267            &hashed_access_token,
2268            session.user_id(),
2269        )
2270        .await?;
2271
2272    let update = proto::UpdateChannels {
2273        dev_servers: vec![dev_server.to_proto(proto::DevServerStatus::Offline)],
2274        ..Default::default()
2275    };
2276    let connection_pool = session.connection_pool().await;
2277    for (connection_id, role) in connection_pool.channel_connection_ids(channel.root_id()) {
2278        if role.can_see_channel(channel.visibility) {
2279            session.peer.send(connection_id, update.clone())?;
2280        }
2281    }
2282
2283    response.send(proto::CreateDevServerResponse {
2284        dev_server_id: dev_server.id.0 as u64,
2285        channel_id: request.channel_id,
2286        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2287        name: request.name.clone(),
2288    })?;
2289    Ok(())
2290}
2291
2292async fn rejoin_remote_projects(
2293    request: proto::RejoinRemoteProjects,
2294    response: Response<proto::RejoinRemoteProjects>,
2295    session: UserSession,
2296) -> Result<()> {
2297    let mut rejoined_projects = {
2298        let db = session.db().await;
2299        db.rejoin_remote_projects(
2300            &request.rejoined_projects,
2301            session.user_id(),
2302            session.0.connection_id,
2303        )
2304        .await?
2305    };
2306    notify_rejoined_projects(&mut rejoined_projects, &session)?;
2307
2308    response.send(proto::RejoinRemoteProjectsResponse {
2309        rejoined_projects: rejoined_projects
2310            .into_iter()
2311            .map(|project| project.to_proto())
2312            .collect(),
2313    })
2314}
2315
2316async fn reconnect_dev_server(
2317    request: proto::ReconnectDevServer,
2318    response: Response<proto::ReconnectDevServer>,
2319    session: DevServerSession,
2320) -> Result<()> {
2321    let reshared_projects = {
2322        let db = session.db().await;
2323        db.reshare_remote_projects(
2324            &request.reshared_projects,
2325            session.dev_server_id(),
2326            session.0.connection_id,
2327        )
2328        .await?
2329    };
2330
2331    for project in &reshared_projects {
2332        for collaborator in &project.collaborators {
2333            session
2334                .peer
2335                .send(
2336                    collaborator.connection_id,
2337                    proto::UpdateProjectCollaborator {
2338                        project_id: project.id.to_proto(),
2339                        old_peer_id: Some(project.old_connection_id.into()),
2340                        new_peer_id: Some(session.connection_id.into()),
2341                    },
2342                )
2343                .trace_err();
2344        }
2345
2346        broadcast(
2347            Some(session.connection_id),
2348            project
2349                .collaborators
2350                .iter()
2351                .map(|collaborator| collaborator.connection_id),
2352            |connection_id| {
2353                session.peer.forward_send(
2354                    session.connection_id,
2355                    connection_id,
2356                    proto::UpdateProject {
2357                        project_id: project.id.to_proto(),
2358                        worktrees: project.worktrees.clone(),
2359                    },
2360                )
2361            },
2362        );
2363    }
2364
2365    response.send(proto::ReconnectDevServerResponse {
2366        reshared_projects: reshared_projects
2367            .iter()
2368            .map(|project| proto::ResharedProject {
2369                id: project.id.to_proto(),
2370                collaborators: project
2371                    .collaborators
2372                    .iter()
2373                    .map(|collaborator| collaborator.to_proto())
2374                    .collect(),
2375            })
2376            .collect(),
2377    })?;
2378
2379    Ok(())
2380}
2381
2382async fn shutdown_dev_server(
2383    _: proto::ShutdownDevServer,
2384    response: Response<proto::ShutdownDevServer>,
2385    session: DevServerSession,
2386) -> Result<()> {
2387    response.send(proto::Ack {})?;
2388    let (remote_projects, dev_server) = {
2389        let dev_server_id = session.dev_server_id();
2390        let db = session.db().await;
2391        let remote_projects = db.get_remote_projects_for_dev_server(dev_server_id).await?;
2392        let dev_server = db.get_dev_server(dev_server_id).await?;
2393        (remote_projects, dev_server)
2394    };
2395
2396    for project_id in remote_projects.iter().filter_map(|p| p.project_id) {
2397        unshare_project_internal(ProjectId::from_proto(project_id), &session.0).await?;
2398    }
2399
2400    let update = proto::UpdateChannels {
2401        remote_projects,
2402        dev_servers: vec![dev_server.to_proto(proto::DevServerStatus::Offline)],
2403        ..Default::default()
2404    };
2405
2406    for (connection_id, _) in session
2407        .connection_pool()
2408        .await
2409        .channel_connection_ids(dev_server.channel_id)
2410    {
2411        session.peer.send(connection_id, update.clone()).trace_err();
2412    }
2413
2414    Ok(())
2415}
2416
/// Updates other participants with changes to the project
///
/// Stores the new worktree list for `request.project_id` through the database
/// (passing the sender's connection id), forwards the original request to every
/// guest connection the database returns, refreshes the room when one is returned,
/// and finally acknowledges the request.
async fn update_project(
    request: proto::UpdateProject,
    response: Response<proto::UpdateProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // `&*` borrows through the value returned by the database; temporary lifetime
    // extension keeps that value (and whatever guard it carries) alive until the end
    // of this function, i.e. across the broadcast and room update below.
    let (room, guest_connection_ids) = &*session
        .db()
        .await
        .update_project(project_id, session.connection_id, &request.worktrees)
        .await?;
    // Forward the request verbatim; `broadcast` skips the sender's own connection.
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    if let Some(room) = room {
        room_updated(&room, &session.peer);
    }
    response.send(proto::Ack {})?;

    Ok(())
}
2445
2446/// Updates other participants with changes to the worktree
2447async fn update_worktree(
2448    request: proto::UpdateWorktree,
2449    response: Response<proto::UpdateWorktree>,
2450    session: Session,
2451) -> Result<()> {
2452    let guest_connection_ids = session
2453        .db()
2454        .await
2455        .update_worktree(&request, session.connection_id)
2456        .await?;
2457
2458    broadcast(
2459        Some(session.connection_id),
2460        guest_connection_ids.iter().copied(),
2461        |connection_id| {
2462            session
2463                .peer
2464                .forward_send(session.connection_id, connection_id, request.clone())
2465        },
2466    );
2467    response.send(proto::Ack {})?;
2468    Ok(())
2469}
2470
2471/// Updates other participants with changes to the diagnostics
2472async fn update_diagnostic_summary(
2473    message: proto::UpdateDiagnosticSummary,
2474    session: Session,
2475) -> Result<()> {
2476    let guest_connection_ids = session
2477        .db()
2478        .await
2479        .update_diagnostic_summary(&message, session.connection_id)
2480        .await?;
2481
2482    broadcast(
2483        Some(session.connection_id),
2484        guest_connection_ids.iter().copied(),
2485        |connection_id| {
2486            session
2487                .peer
2488                .forward_send(session.connection_id, connection_id, message.clone())
2489        },
2490    );
2491
2492    Ok(())
2493}
2494
2495/// Updates other participants with changes to the worktree settings
2496async fn update_worktree_settings(
2497    message: proto::UpdateWorktreeSettings,
2498    session: Session,
2499) -> Result<()> {
2500    let guest_connection_ids = session
2501        .db()
2502        .await
2503        .update_worktree_settings(&message, session.connection_id)
2504        .await?;
2505
2506    broadcast(
2507        Some(session.connection_id),
2508        guest_connection_ids.iter().copied(),
2509        |connection_id| {
2510            session
2511                .peer
2512                .forward_send(session.connection_id, connection_id, message.clone())
2513        },
2514    );
2515
2516    Ok(())
2517}
2518
2519/// Notify other participants that a  language server has started.
2520async fn start_language_server(
2521    request: proto::StartLanguageServer,
2522    session: Session,
2523) -> Result<()> {
2524    let guest_connection_ids = session
2525        .db()
2526        .await
2527        .start_language_server(&request, session.connection_id)
2528        .await?;
2529
2530    broadcast(
2531        Some(session.connection_id),
2532        guest_connection_ids.iter().copied(),
2533        |connection_id| {
2534            session
2535                .peer
2536                .forward_send(session.connection_id, connection_id, request.clone())
2537        },
2538    );
2539    Ok(())
2540}
2541
2542/// Notify other participants that a language server has changed.
2543async fn update_language_server(
2544    request: proto::UpdateLanguageServer,
2545    session: Session,
2546) -> Result<()> {
2547    let project_id = ProjectId::from_proto(request.project_id);
2548    let project_connection_ids = session
2549        .db()
2550        .await
2551        .project_connection_ids(project_id, session.connection_id, true)
2552        .await?;
2553    broadcast(
2554        Some(session.connection_id),
2555        project_connection_ids.iter().copied(),
2556        |connection_id| {
2557            session
2558                .peer
2559                .forward_send(session.connection_id, connection_id, request.clone())
2560        },
2561    );
2562    Ok(())
2563}
2564
2565/// forward a project request to the host. These requests should be read only
2566/// as guests are allowed to send them.
2567async fn forward_read_only_project_request<T>(
2568    request: T,
2569    response: Response<T>,
2570    session: UserSession,
2571) -> Result<()>
2572where
2573    T: EntityMessage + RequestMessage,
2574{
2575    let project_id = ProjectId::from_proto(request.remote_entity_id());
2576    let host_connection_id = session
2577        .db()
2578        .await
2579        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2580        .await?;
2581    let payload = session
2582        .peer
2583        .forward_request(session.connection_id, host_connection_id, request)
2584        .await?;
2585    response.send(payload)?;
2586    Ok(())
2587}
2588
2589/// forward a project request to the host. These requests are disallowed
2590/// for guests.
2591async fn forward_mutating_project_request<T>(
2592    request: T,
2593    response: Response<T>,
2594    session: UserSession,
2595) -> Result<()>
2596where
2597    T: EntityMessage + RequestMessage,
2598{
2599    let project_id = ProjectId::from_proto(request.remote_entity_id());
2600    let host_connection_id = session
2601        .db()
2602        .await
2603        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2604        .await?;
2605    let payload = session
2606        .peer
2607        .forward_request(session.connection_id, host_connection_id, request)
2608        .await?;
2609    response.send(payload)?;
2610    Ok(())
2611}
2612
2613/// Notify other participants that a new buffer has been created
2614async fn create_buffer_for_peer(
2615    request: proto::CreateBufferForPeer,
2616    session: Session,
2617) -> Result<()> {
2618    session
2619        .db()
2620        .await
2621        .check_user_is_project_host(
2622            ProjectId::from_proto(request.project_id),
2623            session.connection_id,
2624        )
2625        .await?;
2626    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2627    session
2628        .peer
2629        .forward_send(session.connection_id, peer_id.into(), request)?;
2630    Ok(())
2631}
2632
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
///
/// The update is forwarded to the guest connections returned by the database. If the
/// sender is not the host, it is also sent to the host as a request, and the handler
/// waits for that request to complete before acknowledging the sender.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let mut capability = Capability::ReadOnly;

    // Selection updates (and operations without a variant) only need read access;
    // any other operation requires write access.
    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // The guard returned by `connections_for_buffer_update` lives only inside this
    // block. Guests are notified while it is held, and it is dropped before the
    // host request below is awaited.
    let host = {
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(
                project_id,
                session.principal_id(),
                session.connection_id,
                capability,
            )
            .await?;

        let (host, guests) = &*guard;

        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // The host is sent the update as a request (not a plain message), so the sender
    // is only acknowledged once the host has answered.
    if host != session.connection_id {
        session
            .peer
            .forward_request(session.connection_id, host, request.clone())
            .await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2687
2688/// Notify other participants that a project has been updated.
2689async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2690    request: T,
2691    session: Session,
2692) -> Result<()> {
2693    let project_id = ProjectId::from_proto(request.remote_entity_id());
2694    let project_connection_ids = session
2695        .db()
2696        .await
2697        .project_connection_ids(project_id, session.connection_id, false)
2698        .await?;
2699
2700    broadcast(
2701        Some(session.connection_id),
2702        project_connection_ids.iter().copied(),
2703        |connection_id| {
2704            session
2705                .peer
2706                .forward_send(session.connection_id, connection_id, request.clone())
2707        },
2708    );
2709    Ok(())
2710}
2711
2712/// Start following another user in a call.
2713async fn follow(
2714    request: proto::Follow,
2715    response: Response<proto::Follow>,
2716    session: UserSession,
2717) -> Result<()> {
2718    let room_id = RoomId::from_proto(request.room_id);
2719    let project_id = request.project_id.map(ProjectId::from_proto);
2720    let leader_id = request
2721        .leader_id
2722        .ok_or_else(|| anyhow!("invalid leader id"))?
2723        .into();
2724    let follower_id = session.connection_id;
2725
2726    session
2727        .db()
2728        .await
2729        .check_room_participants(room_id, leader_id, session.connection_id)
2730        .await?;
2731
2732    let response_payload = session
2733        .peer
2734        .forward_request(session.connection_id, leader_id, request)
2735        .await?;
2736    response.send(response_payload)?;
2737
2738    if let Some(project_id) = project_id {
2739        let room = session
2740            .db()
2741            .await
2742            .follow(room_id, project_id, leader_id, follower_id)
2743            .await?;
2744        room_updated(&room, &session.peer);
2745    }
2746
2747    Ok(())
2748}
2749
2750/// Stop following another user in a call.
2751async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
2752    let room_id = RoomId::from_proto(request.room_id);
2753    let project_id = request.project_id.map(ProjectId::from_proto);
2754    let leader_id = request
2755        .leader_id
2756        .ok_or_else(|| anyhow!("invalid leader id"))?
2757        .into();
2758    let follower_id = session.connection_id;
2759
2760    session
2761        .db()
2762        .await
2763        .check_room_participants(room_id, leader_id, session.connection_id)
2764        .await?;
2765
2766    session
2767        .peer
2768        .forward_send(session.connection_id, leader_id, request)?;
2769
2770    if let Some(project_id) = project_id {
2771        let room = session
2772            .db()
2773            .await
2774            .unfollow(room_id, project_id, leader_id, follower_id)
2775            .await?;
2776        room_updated(&room, &session.peer);
2777    }
2778
2779    Ok(())
2780}
2781
2782/// Notify everyone following you of your current location.
2783async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
2784    let room_id = RoomId::from_proto(request.room_id);
2785    let database = session.db.lock().await;
2786
2787    let connection_ids = if let Some(project_id) = request.project_id {
2788        let project_id = ProjectId::from_proto(project_id);
2789        database
2790            .project_connection_ids(project_id, session.connection_id, true)
2791            .await?
2792    } else {
2793        database
2794            .room_connection_ids(room_id, session.connection_id)
2795            .await?
2796    };
2797
2798    // For now, don't send view update messages back to that view's current leader.
2799    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2800        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2801        _ => None,
2802    });
2803
2804    for connection_id in connection_ids.iter().cloned() {
2805        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2806            session
2807                .peer
2808                .forward_send(session.connection_id, connection_id, request.clone())?;
2809        }
2810    }
2811    Ok(())
2812}
2813
2814/// Get public data about users.
2815async fn get_users(
2816    request: proto::GetUsers,
2817    response: Response<proto::GetUsers>,
2818    session: Session,
2819) -> Result<()> {
2820    let user_ids = request
2821        .user_ids
2822        .into_iter()
2823        .map(UserId::from_proto)
2824        .collect();
2825    let users = session
2826        .db()
2827        .await
2828        .get_users_by_ids(user_ids)
2829        .await?
2830        .into_iter()
2831        .map(|user| proto::User {
2832            id: user.id.to_proto(),
2833            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2834            github_login: user.github_login,
2835        })
2836        .collect();
2837    response.send(proto::UsersResponse { users })?;
2838    Ok(())
2839}
2840
2841/// Search for users (to invite) buy Github login
2842async fn fuzzy_search_users(
2843    request: proto::FuzzySearchUsers,
2844    response: Response<proto::FuzzySearchUsers>,
2845    session: UserSession,
2846) -> Result<()> {
2847    let query = request.query;
2848    let users = match query.len() {
2849        0 => vec![],
2850        1 | 2 => session
2851            .db()
2852            .await
2853            .get_user_by_github_login(&query)
2854            .await?
2855            .into_iter()
2856            .collect(),
2857        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2858    };
2859    let users = users
2860        .into_iter()
2861        .filter(|user| user.id != session.user_id())
2862        .map(|user| proto::User {
2863            id: user.id.to_proto(),
2864            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2865            github_login: user.github_login,
2866        })
2867        .collect();
2868    response.send(proto::UsersResponse { users })?;
2869    Ok(())
2870}
2871
2872/// Send a contact request to another user.
2873async fn request_contact(
2874    request: proto::RequestContact,
2875    response: Response<proto::RequestContact>,
2876    session: UserSession,
2877) -> Result<()> {
2878    let requester_id = session.user_id();
2879    let responder_id = UserId::from_proto(request.responder_id);
2880    if requester_id == responder_id {
2881        return Err(anyhow!("cannot add yourself as a contact"))?;
2882    }
2883
2884    let notifications = session
2885        .db()
2886        .await
2887        .send_contact_request(requester_id, responder_id)
2888        .await?;
2889
2890    // Update outgoing contact requests of requester
2891    let mut update = proto::UpdateContacts::default();
2892    update.outgoing_requests.push(responder_id.to_proto());
2893    for connection_id in session
2894        .connection_pool()
2895        .await
2896        .user_connection_ids(requester_id)
2897    {
2898        session.peer.send(connection_id, update.clone())?;
2899    }
2900
2901    // Update incoming contact requests of responder
2902    let mut update = proto::UpdateContacts::default();
2903    update
2904        .incoming_requests
2905        .push(proto::IncomingContactRequest {
2906            requester_id: requester_id.to_proto(),
2907        });
2908    let connection_pool = session.connection_pool().await;
2909    for connection_id in connection_pool.user_connection_ids(responder_id) {
2910        session.peer.send(connection_id, update.clone())?;
2911    }
2912
2913    send_notifications(&connection_pool, &session.peer, notifications);
2914
2915    response.send(proto::Ack {})?;
2916    Ok(())
2917}
2918
2919/// Accept or decline a contact request
2920async fn respond_to_contact_request(
2921    request: proto::RespondToContactRequest,
2922    response: Response<proto::RespondToContactRequest>,
2923    session: UserSession,
2924) -> Result<()> {
2925    let responder_id = session.user_id();
2926    let requester_id = UserId::from_proto(request.requester_id);
2927    let db = session.db().await;
2928    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2929        db.dismiss_contact_notification(responder_id, requester_id)
2930            .await?;
2931    } else {
2932        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2933
2934        let notifications = db
2935            .respond_to_contact_request(responder_id, requester_id, accept)
2936            .await?;
2937        let requester_busy = db.is_user_busy(requester_id).await?;
2938        let responder_busy = db.is_user_busy(responder_id).await?;
2939
2940        let pool = session.connection_pool().await;
2941        // Update responder with new contact
2942        let mut update = proto::UpdateContacts::default();
2943        if accept {
2944            update
2945                .contacts
2946                .push(contact_for_user(requester_id, requester_busy, &pool));
2947        }
2948        update
2949            .remove_incoming_requests
2950            .push(requester_id.to_proto());
2951        for connection_id in pool.user_connection_ids(responder_id) {
2952            session.peer.send(connection_id, update.clone())?;
2953        }
2954
2955        // Update requester with new contact
2956        let mut update = proto::UpdateContacts::default();
2957        if accept {
2958            update
2959                .contacts
2960                .push(contact_for_user(responder_id, responder_busy, &pool));
2961        }
2962        update
2963            .remove_outgoing_requests
2964            .push(responder_id.to_proto());
2965
2966        for connection_id in pool.user_connection_ids(requester_id) {
2967            session.peer.send(connection_id, update.clone())?;
2968        }
2969
2970        send_notifications(&pool, &session.peer, notifications);
2971    }
2972
2973    response.send(proto::Ack {})?;
2974    Ok(())
2975}
2976
2977/// Remove a contact.
2978async fn remove_contact(
2979    request: proto::RemoveContact,
2980    response: Response<proto::RemoveContact>,
2981    session: UserSession,
2982) -> Result<()> {
2983    let requester_id = session.user_id();
2984    let responder_id = UserId::from_proto(request.user_id);
2985    let db = session.db().await;
2986    let (contact_accepted, deleted_notification_id) =
2987        db.remove_contact(requester_id, responder_id).await?;
2988
2989    let pool = session.connection_pool().await;
2990    // Update outgoing contact requests of requester
2991    let mut update = proto::UpdateContacts::default();
2992    if contact_accepted {
2993        update.remove_contacts.push(responder_id.to_proto());
2994    } else {
2995        update
2996            .remove_outgoing_requests
2997            .push(responder_id.to_proto());
2998    }
2999    for connection_id in pool.user_connection_ids(requester_id) {
3000        session.peer.send(connection_id, update.clone())?;
3001    }
3002
3003    // Update incoming contact requests of responder
3004    let mut update = proto::UpdateContacts::default();
3005    if contact_accepted {
3006        update.remove_contacts.push(requester_id.to_proto());
3007    } else {
3008        update
3009            .remove_incoming_requests
3010            .push(requester_id.to_proto());
3011    }
3012    for connection_id in pool.user_connection_ids(responder_id) {
3013        session.peer.send(connection_id, update.clone())?;
3014        if let Some(notification_id) = deleted_notification_id {
3015            session.peer.send(
3016                connection_id,
3017                proto::DeleteNotification {
3018                    notification_id: notification_id.to_proto(),
3019                },
3020            )?;
3021        }
3022    }
3023
3024    response.send(proto::Ack {})?;
3025    Ok(())
3026}
3027
3028/// Creates a new channel.
3029async fn create_channel(
3030    request: proto::CreateChannel,
3031    response: Response<proto::CreateChannel>,
3032    session: UserSession,
3033) -> Result<()> {
3034    let db = session.db().await;
3035
3036    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3037    let (channel, membership) = db
3038        .create_channel(&request.name, parent_id, session.user_id())
3039        .await?;
3040
3041    let root_id = channel.root_id();
3042    let channel = Channel::from_model(channel);
3043
3044    response.send(proto::CreateChannelResponse {
3045        channel: Some(channel.to_proto()),
3046        parent_id: request.parent_id,
3047    })?;
3048
3049    let mut connection_pool = session.connection_pool().await;
3050    if let Some(membership) = membership {
3051        connection_pool.subscribe_to_channel(
3052            membership.user_id,
3053            membership.channel_id,
3054            membership.role,
3055        );
3056        let update = proto::UpdateUserChannels {
3057            channel_memberships: vec![proto::ChannelMembership {
3058                channel_id: membership.channel_id.to_proto(),
3059                role: membership.role.into(),
3060            }],
3061            ..Default::default()
3062        };
3063        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3064            session.peer.send(connection_id, update.clone())?;
3065        }
3066    }
3067
3068    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3069        if !role.can_see_channel(channel.visibility) {
3070            continue;
3071        }
3072
3073        let update = proto::UpdateChannels {
3074            channels: vec![channel.to_proto()],
3075            ..Default::default()
3076        };
3077        session.peer.send(connection_id, update.clone())?;
3078    }
3079
3080    Ok(())
3081}
3082
3083/// Delete a channel
3084async fn delete_channel(
3085    request: proto::DeleteChannel,
3086    response: Response<proto::DeleteChannel>,
3087    session: UserSession,
3088) -> Result<()> {
3089    let db = session.db().await;
3090
3091    let channel_id = request.channel_id;
3092    let (root_channel, removed_channels) = db
3093        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3094        .await?;
3095    response.send(proto::Ack {})?;
3096
3097    // Notify members of removed channels
3098    let mut update = proto::UpdateChannels::default();
3099    update
3100        .delete_channels
3101        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3102
3103    let connection_pool = session.connection_pool().await;
3104    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3105        session.peer.send(connection_id, update.clone())?;
3106    }
3107
3108    Ok(())
3109}
3110
3111/// Invite someone to join a channel.
3112async fn invite_channel_member(
3113    request: proto::InviteChannelMember,
3114    response: Response<proto::InviteChannelMember>,
3115    session: UserSession,
3116) -> Result<()> {
3117    let db = session.db().await;
3118    let channel_id = ChannelId::from_proto(request.channel_id);
3119    let invitee_id = UserId::from_proto(request.user_id);
3120    let InviteMemberResult {
3121        channel,
3122        notifications,
3123    } = db
3124        .invite_channel_member(
3125            channel_id,
3126            invitee_id,
3127            session.user_id(),
3128            request.role().into(),
3129        )
3130        .await?;
3131
3132    let update = proto::UpdateChannels {
3133        channel_invitations: vec![channel.to_proto()],
3134        ..Default::default()
3135    };
3136
3137    let connection_pool = session.connection_pool().await;
3138    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3139        session.peer.send(connection_id, update.clone())?;
3140    }
3141
3142    send_notifications(&connection_pool, &session.peer, notifications);
3143
3144    response.send(proto::Ack {})?;
3145    Ok(())
3146}
3147
3148/// remove someone from a channel
3149async fn remove_channel_member(
3150    request: proto::RemoveChannelMember,
3151    response: Response<proto::RemoveChannelMember>,
3152    session: UserSession,
3153) -> Result<()> {
3154    let db = session.db().await;
3155    let channel_id = ChannelId::from_proto(request.channel_id);
3156    let member_id = UserId::from_proto(request.user_id);
3157
3158    let RemoveChannelMemberResult {
3159        membership_update,
3160        notification_id,
3161    } = db
3162        .remove_channel_member(channel_id, member_id, session.user_id())
3163        .await?;
3164
3165    let mut connection_pool = session.connection_pool().await;
3166    notify_membership_updated(
3167        &mut connection_pool,
3168        membership_update,
3169        member_id,
3170        &session.peer,
3171    );
3172    for connection_id in connection_pool.user_connection_ids(member_id) {
3173        if let Some(notification_id) = notification_id {
3174            session
3175                .peer
3176                .send(
3177                    connection_id,
3178                    proto::DeleteNotification {
3179                        notification_id: notification_id.to_proto(),
3180                    },
3181                )
3182                .trace_err();
3183        }
3184    }
3185
3186    response.send(proto::Ack {})?;
3187    Ok(())
3188}
3189
3190/// Toggle the channel between public and private.
3191/// Care is taken to maintain the invariant that public channels only descend from public channels,
3192/// (though members-only channels can appear at any point in the hierarchy).
3193async fn set_channel_visibility(
3194    request: proto::SetChannelVisibility,
3195    response: Response<proto::SetChannelVisibility>,
3196    session: UserSession,
3197) -> Result<()> {
3198    let db = session.db().await;
3199    let channel_id = ChannelId::from_proto(request.channel_id);
3200    let visibility = request.visibility().into();
3201
3202    let channel_model = db
3203        .set_channel_visibility(channel_id, visibility, session.user_id())
3204        .await?;
3205    let root_id = channel_model.root_id();
3206    let channel = Channel::from_model(channel_model);
3207
3208    let mut connection_pool = session.connection_pool().await;
3209    for (user_id, role) in connection_pool
3210        .channel_user_ids(root_id)
3211        .collect::<Vec<_>>()
3212        .into_iter()
3213    {
3214        let update = if role.can_see_channel(channel.visibility) {
3215            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3216            proto::UpdateChannels {
3217                channels: vec![channel.to_proto()],
3218                ..Default::default()
3219            }
3220        } else {
3221            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3222            proto::UpdateChannels {
3223                delete_channels: vec![channel.id.to_proto()],
3224                ..Default::default()
3225            }
3226        };
3227
3228        for connection_id in connection_pool.user_connection_ids(user_id) {
3229            session.peer.send(connection_id, update.clone())?;
3230        }
3231    }
3232
3233    response.send(proto::Ack {})?;
3234    Ok(())
3235}
3236
3237/// Alter the role for a user in the channel.
3238async fn set_channel_member_role(
3239    request: proto::SetChannelMemberRole,
3240    response: Response<proto::SetChannelMemberRole>,
3241    session: UserSession,
3242) -> Result<()> {
3243    let db = session.db().await;
3244    let channel_id = ChannelId::from_proto(request.channel_id);
3245    let member_id = UserId::from_proto(request.user_id);
3246    let result = db
3247        .set_channel_member_role(
3248            channel_id,
3249            session.user_id(),
3250            member_id,
3251            request.role().into(),
3252        )
3253        .await?;
3254
3255    match result {
3256        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3257            let mut connection_pool = session.connection_pool().await;
3258            notify_membership_updated(
3259                &mut connection_pool,
3260                membership_update,
3261                member_id,
3262                &session.peer,
3263            )
3264        }
3265        db::SetMemberRoleResult::InviteUpdated(channel) => {
3266            let update = proto::UpdateChannels {
3267                channel_invitations: vec![channel.to_proto()],
3268                ..Default::default()
3269            };
3270
3271            for connection_id in session
3272                .connection_pool()
3273                .await
3274                .user_connection_ids(member_id)
3275            {
3276                session.peer.send(connection_id, update.clone())?;
3277            }
3278        }
3279    }
3280
3281    response.send(proto::Ack {})?;
3282    Ok(())
3283}
3284
3285/// Change the name of a channel
3286async fn rename_channel(
3287    request: proto::RenameChannel,
3288    response: Response<proto::RenameChannel>,
3289    session: UserSession,
3290) -> Result<()> {
3291    let db = session.db().await;
3292    let channel_id = ChannelId::from_proto(request.channel_id);
3293    let channel_model = db
3294        .rename_channel(channel_id, session.user_id(), &request.name)
3295        .await?;
3296    let root_id = channel_model.root_id();
3297    let channel = Channel::from_model(channel_model);
3298
3299    response.send(proto::RenameChannelResponse {
3300        channel: Some(channel.to_proto()),
3301    })?;
3302
3303    let connection_pool = session.connection_pool().await;
3304    let update = proto::UpdateChannels {
3305        channels: vec![channel.to_proto()],
3306        ..Default::default()
3307    };
3308    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3309        if role.can_see_channel(channel.visibility) {
3310            session.peer.send(connection_id, update.clone())?;
3311        }
3312    }
3313
3314    Ok(())
3315}
3316
3317/// Move a channel to a new parent.
3318async fn move_channel(
3319    request: proto::MoveChannel,
3320    response: Response<proto::MoveChannel>,
3321    session: UserSession,
3322) -> Result<()> {
3323    let channel_id = ChannelId::from_proto(request.channel_id);
3324    let to = ChannelId::from_proto(request.to);
3325
3326    let (root_id, channels) = session
3327        .db()
3328        .await
3329        .move_channel(channel_id, to, session.user_id())
3330        .await?;
3331
3332    let connection_pool = session.connection_pool().await;
3333    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3334        let channels = channels
3335            .iter()
3336            .filter_map(|channel| {
3337                if role.can_see_channel(channel.visibility) {
3338                    Some(channel.to_proto())
3339                } else {
3340                    None
3341                }
3342            })
3343            .collect::<Vec<_>>();
3344        if channels.is_empty() {
3345            continue;
3346        }
3347
3348        let update = proto::UpdateChannels {
3349            channels,
3350            ..Default::default()
3351        };
3352
3353        session.peer.send(connection_id, update.clone())?;
3354    }
3355
3356    response.send(Ack {})?;
3357    Ok(())
3358}
3359
3360/// Get the list of channel members
3361async fn get_channel_members(
3362    request: proto::GetChannelMembers,
3363    response: Response<proto::GetChannelMembers>,
3364    session: UserSession,
3365) -> Result<()> {
3366    let db = session.db().await;
3367    let channel_id = ChannelId::from_proto(request.channel_id);
3368    let members = db
3369        .get_channel_participant_details(channel_id, session.user_id())
3370        .await?;
3371    response.send(proto::GetChannelMembersResponse { members })?;
3372    Ok(())
3373}
3374
3375/// Accept or decline a channel invitation.
3376async fn respond_to_channel_invite(
3377    request: proto::RespondToChannelInvite,
3378    response: Response<proto::RespondToChannelInvite>,
3379    session: UserSession,
3380) -> Result<()> {
3381    let db = session.db().await;
3382    let channel_id = ChannelId::from_proto(request.channel_id);
3383    let RespondToChannelInvite {
3384        membership_update,
3385        notifications,
3386    } = db
3387        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3388        .await?;
3389
3390    let mut connection_pool = session.connection_pool().await;
3391    if let Some(membership_update) = membership_update {
3392        notify_membership_updated(
3393            &mut connection_pool,
3394            membership_update,
3395            session.user_id(),
3396            &session.peer,
3397        );
3398    } else {
3399        let update = proto::UpdateChannels {
3400            remove_channel_invitations: vec![channel_id.to_proto()],
3401            ..Default::default()
3402        };
3403
3404        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3405            session.peer.send(connection_id, update.clone())?;
3406        }
3407    };
3408
3409    send_notifications(&connection_pool, &session.peer, notifications);
3410
3411    response.send(proto::Ack {})?;
3412
3413    Ok(())
3414}
3415
3416/// Join the channels' room
3417async fn join_channel(
3418    request: proto::JoinChannel,
3419    response: Response<proto::JoinChannel>,
3420    session: UserSession,
3421) -> Result<()> {
3422    let channel_id = ChannelId::from_proto(request.channel_id);
3423    join_channel_internal(channel_id, Box::new(response), session).await
3424}
3425
/// A response handle that `join_channel_internal` can reply through.
///
/// The join result is always a `proto::JoinRoomResponse`. This trait lets
/// either a `JoinChannel` or a `JoinRoom` request handle deliver it.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the originating request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified call: inherent associated items take precedence over
        // trait items, so this dispatches to `Response`'s own `send` rather than
        // recursing into this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same as above: resolves to `Response`'s own `send`, not this trait's.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3439
/// Joins the room belonging to `channel_id` on behalf of `session`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, that
///    connection leaves its room first (`leave_room_for_session`).
/// 2. The user joins the channel's room in the database.
/// 3. If a LiveKit client is configured, a token is minted. Guests get a
///    guest token and `can_publish = false`; every other role gets a room
///    token and `can_publish = true`. A token failure is passed through
///    `trace_err` and results in no LiveKit info in the reply.
/// 4. The join result is sent through `response`.
/// 5. Membership changes (if any) are broadcast, followed by the room update
///    and then the channel update.
/// 6. The user's contacts are refreshed.
///
/// # Errors
/// Fails if any of these fail:
/// - stale-connection cleanup, the database join, sending the response, or
///   updating contacts;
/// - the database returning no channel for the joined room.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database handle while leaving the old room, then
            // re-acquire it for the join below.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // LiveKit info is optional. It is `None` when no LiveKit client is
        // configured, or when minting the token fails (`trace_err()?` turns the
        // error into `None` for this closure).
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests may only listen; all other roles may publish.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Reply to the joining client before broadcasting to others. Note that
        // `db` is still held at this point (until the end of this block).
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // Leaving the block drops both `db` and this `connection_pool` guard.
        joined_room
    };

    // A fresh connection-pool guard is taken here, after the guards above have
    // been dropped.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3530
3531/// Start editing the channel notes
3532async fn join_channel_buffer(
3533    request: proto::JoinChannelBuffer,
3534    response: Response<proto::JoinChannelBuffer>,
3535    session: UserSession,
3536) -> Result<()> {
3537    let db = session.db().await;
3538    let channel_id = ChannelId::from_proto(request.channel_id);
3539
3540    let open_response = db
3541        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3542        .await?;
3543
3544    let collaborators = open_response.collaborators.clone();
3545    response.send(open_response)?;
3546
3547    let update = UpdateChannelBufferCollaborators {
3548        channel_id: channel_id.to_proto(),
3549        collaborators: collaborators.clone(),
3550    };
3551    channel_buffer_updated(
3552        session.connection_id,
3553        collaborators
3554            .iter()
3555            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3556        &update,
3557        &session.peer,
3558    );
3559
3560    Ok(())
3561}
3562
3563/// Edit the channel notes
3564async fn update_channel_buffer(
3565    request: proto::UpdateChannelBuffer,
3566    session: UserSession,
3567) -> Result<()> {
3568    let db = session.db().await;
3569    let channel_id = ChannelId::from_proto(request.channel_id);
3570
3571    let (collaborators, non_collaborators, epoch, version) = db
3572        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3573        .await?;
3574
3575    channel_buffer_updated(
3576        session.connection_id,
3577        collaborators,
3578        &proto::UpdateChannelBuffer {
3579            channel_id: channel_id.to_proto(),
3580            operations: request.operations,
3581        },
3582        &session.peer,
3583    );
3584
3585    let pool = &*session.connection_pool().await;
3586
3587    broadcast(
3588        None,
3589        non_collaborators
3590            .iter()
3591            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3592        |peer_id| {
3593            session.peer.send(
3594                peer_id,
3595                proto::UpdateChannels {
3596                    latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3597                        channel_id: channel_id.to_proto(),
3598                        epoch: epoch as u64,
3599                        version: version.clone(),
3600                    }],
3601                    ..Default::default()
3602                },
3603            )
3604        },
3605    );
3606
3607    Ok(())
3608}
3609
3610/// Rejoin the channel notes after a connection blip
3611async fn rejoin_channel_buffers(
3612    request: proto::RejoinChannelBuffers,
3613    response: Response<proto::RejoinChannelBuffers>,
3614    session: UserSession,
3615) -> Result<()> {
3616    let db = session.db().await;
3617    let buffers = db
3618        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3619        .await?;
3620
3621    for rejoined_buffer in &buffers {
3622        let collaborators_to_notify = rejoined_buffer
3623            .buffer
3624            .collaborators
3625            .iter()
3626            .filter_map(|c| Some(c.peer_id?.into()));
3627        channel_buffer_updated(
3628            session.connection_id,
3629            collaborators_to_notify,
3630            &proto::UpdateChannelBufferCollaborators {
3631                channel_id: rejoined_buffer.buffer.channel_id,
3632                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3633            },
3634            &session.peer,
3635        );
3636    }
3637
3638    response.send(proto::RejoinChannelBuffersResponse {
3639        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3640    })?;
3641
3642    Ok(())
3643}
3644
3645/// Stop editing the channel notes
3646async fn leave_channel_buffer(
3647    request: proto::LeaveChannelBuffer,
3648    response: Response<proto::LeaveChannelBuffer>,
3649    session: UserSession,
3650) -> Result<()> {
3651    let db = session.db().await;
3652    let channel_id = ChannelId::from_proto(request.channel_id);
3653
3654    let left_buffer = db
3655        .leave_channel_buffer(channel_id, session.connection_id)
3656        .await?;
3657
3658    response.send(Ack {})?;
3659
3660    channel_buffer_updated(
3661        session.connection_id,
3662        left_buffer.connections,
3663        &proto::UpdateChannelBufferCollaborators {
3664            channel_id: channel_id.to_proto(),
3665            collaborators: left_buffer.collaborators,
3666        },
3667        &session.peer,
3668    );
3669
3670    Ok(())
3671}
3672
3673fn channel_buffer_updated<T: EnvelopedMessage>(
3674    sender_id: ConnectionId,
3675    collaborators: impl IntoIterator<Item = ConnectionId>,
3676    message: &T,
3677    peer: &Peer,
3678) {
3679    broadcast(Some(sender_id), collaborators, |peer_id| {
3680        peer.send(peer_id, message.clone())
3681    });
3682}
3683
3684fn send_notifications(
3685    connection_pool: &ConnectionPool,
3686    peer: &Peer,
3687    notifications: db::NotificationBatch,
3688) {
3689    for (user_id, notification) in notifications {
3690        for connection_id in connection_pool.user_connection_ids(user_id) {
3691            if let Err(error) = peer.send(
3692                connection_id,
3693                proto::AddNotification {
3694                    notification: Some(notification.clone()),
3695                },
3696            ) {
3697                tracing::error!(
3698                    "failed to send notification to {:?} {}",
3699                    connection_id,
3700                    error
3701                );
3702            }
3703        }
3704    }
3705}
3706
3707/// Send a message to the channel
3708async fn send_channel_message(
3709    request: proto::SendChannelMessage,
3710    response: Response<proto::SendChannelMessage>,
3711    session: UserSession,
3712) -> Result<()> {
3713    // Validate the message body.
3714    let body = request.body.trim().to_string();
3715    if body.len() > MAX_MESSAGE_LEN {
3716        return Err(anyhow!("message is too long"))?;
3717    }
3718    if body.is_empty() {
3719        return Err(anyhow!("message can't be blank"))?;
3720    }
3721
3722    // TODO: adjust mentions if body is trimmed
3723
3724    let timestamp = OffsetDateTime::now_utc();
3725    let nonce = request
3726        .nonce
3727        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3728
3729    let channel_id = ChannelId::from_proto(request.channel_id);
3730    let CreatedChannelMessage {
3731        message_id,
3732        participant_connection_ids,
3733        channel_members,
3734        notifications,
3735    } = session
3736        .db()
3737        .await
3738        .create_channel_message(
3739            channel_id,
3740            session.user_id(),
3741            &body,
3742            &request.mentions,
3743            timestamp,
3744            nonce.clone().into(),
3745            match request.reply_to_message_id {
3746                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3747                None => None,
3748            },
3749        )
3750        .await?;
3751
3752    let message = proto::ChannelMessage {
3753        sender_id: session.user_id().to_proto(),
3754        id: message_id.to_proto(),
3755        body,
3756        mentions: request.mentions,
3757        timestamp: timestamp.unix_timestamp() as u64,
3758        nonce: Some(nonce),
3759        reply_to_message_id: request.reply_to_message_id,
3760        edited_at: None,
3761    };
3762    broadcast(
3763        Some(session.connection_id),
3764        participant_connection_ids,
3765        |connection| {
3766            session.peer.send(
3767                connection,
3768                proto::ChannelMessageSent {
3769                    channel_id: channel_id.to_proto(),
3770                    message: Some(message.clone()),
3771                },
3772            )
3773        },
3774    );
3775    response.send(proto::SendChannelMessageResponse {
3776        message: Some(message),
3777    })?;
3778
3779    let pool = &*session.connection_pool().await;
3780    broadcast(
3781        None,
3782        channel_members
3783            .iter()
3784            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3785        |peer_id| {
3786            session.peer.send(
3787                peer_id,
3788                proto::UpdateChannels {
3789                    latest_channel_message_ids: vec![proto::ChannelMessageId {
3790                        channel_id: channel_id.to_proto(),
3791                        message_id: message_id.to_proto(),
3792                    }],
3793                    ..Default::default()
3794                },
3795            )
3796        },
3797    );
3798    send_notifications(pool, &session.peer, notifications);
3799
3800    Ok(())
3801}
3802
3803/// Delete a channel message
3804async fn remove_channel_message(
3805    request: proto::RemoveChannelMessage,
3806    response: Response<proto::RemoveChannelMessage>,
3807    session: UserSession,
3808) -> Result<()> {
3809    let channel_id = ChannelId::from_proto(request.channel_id);
3810    let message_id = MessageId::from_proto(request.message_id);
3811    let (connection_ids, existing_notification_ids) = session
3812        .db()
3813        .await
3814        .remove_channel_message(channel_id, message_id, session.user_id())
3815        .await?;
3816
3817    broadcast(
3818        Some(session.connection_id),
3819        connection_ids,
3820        move |connection| {
3821            session.peer.send(connection, request.clone())?;
3822
3823            for notification_id in &existing_notification_ids {
3824                session.peer.send(
3825                    connection,
3826                    proto::DeleteNotification {
3827                        notification_id: (*notification_id).to_proto(),
3828                    },
3829                )?;
3830            }
3831
3832            Ok(())
3833        },
3834    );
3835    response.send(proto::Ack {})?;
3836    Ok(())
3837}
3838
3839async fn update_channel_message(
3840    request: proto::UpdateChannelMessage,
3841    response: Response<proto::UpdateChannelMessage>,
3842    session: UserSession,
3843) -> Result<()> {
3844    let channel_id = ChannelId::from_proto(request.channel_id);
3845    let message_id = MessageId::from_proto(request.message_id);
3846    let updated_at = OffsetDateTime::now_utc();
3847    let UpdatedChannelMessage {
3848        message_id,
3849        participant_connection_ids,
3850        notifications,
3851        reply_to_message_id,
3852        timestamp,
3853        deleted_mention_notification_ids,
3854        updated_mention_notifications,
3855    } = session
3856        .db()
3857        .await
3858        .update_channel_message(
3859            channel_id,
3860            message_id,
3861            session.user_id(),
3862            request.body.as_str(),
3863            &request.mentions,
3864            updated_at,
3865        )
3866        .await?;
3867
3868    let nonce = request
3869        .nonce
3870        .clone()
3871        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3872
3873    let message = proto::ChannelMessage {
3874        sender_id: session.user_id().to_proto(),
3875        id: message_id.to_proto(),
3876        body: request.body.clone(),
3877        mentions: request.mentions.clone(),
3878        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3879        nonce: Some(nonce),
3880        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3881        edited_at: Some(updated_at.unix_timestamp() as u64),
3882    };
3883
3884    response.send(proto::Ack {})?;
3885
3886    let pool = &*session.connection_pool().await;
3887    broadcast(
3888        Some(session.connection_id),
3889        participant_connection_ids,
3890        |connection| {
3891            session.peer.send(
3892                connection,
3893                proto::ChannelMessageUpdate {
3894                    channel_id: channel_id.to_proto(),
3895                    message: Some(message.clone()),
3896                },
3897            )?;
3898
3899            for notification_id in &deleted_mention_notification_ids {
3900                session.peer.send(
3901                    connection,
3902                    proto::DeleteNotification {
3903                        notification_id: (*notification_id).to_proto(),
3904                    },
3905                )?;
3906            }
3907
3908            for notification in &updated_mention_notifications {
3909                session.peer.send(
3910                    connection,
3911                    proto::UpdateNotification {
3912                        notification: Some(notification.clone()),
3913                    },
3914                )?;
3915            }
3916
3917            Ok(())
3918        },
3919    );
3920
3921    send_notifications(pool, &session.peer, notifications);
3922
3923    Ok(())
3924}
3925
3926/// Mark a channel message as read
3927async fn acknowledge_channel_message(
3928    request: proto::AckChannelMessage,
3929    session: UserSession,
3930) -> Result<()> {
3931    let channel_id = ChannelId::from_proto(request.channel_id);
3932    let message_id = MessageId::from_proto(request.message_id);
3933    let notifications = session
3934        .db()
3935        .await
3936        .observe_channel_message(channel_id, session.user_id(), message_id)
3937        .await?;
3938    send_notifications(
3939        &*session.connection_pool().await,
3940        &session.peer,
3941        notifications,
3942    );
3943    Ok(())
3944}
3945
3946/// Mark a buffer version as synced
3947async fn acknowledge_buffer_version(
3948    request: proto::AckBufferOperation,
3949    session: UserSession,
3950) -> Result<()> {
3951    let buffer_id = BufferId::from_proto(request.buffer_id);
3952    session
3953        .db()
3954        .await
3955        .observe_buffer_version(
3956            buffer_id,
3957            session.user_id(),
3958            request.epoch as i32,
3959            &request.version,
3960        )
3961        .await?;
3962    Ok(())
3963}
3964
3965struct CompleteWithLanguageModelRateLimit;
3966
3967impl RateLimit for CompleteWithLanguageModelRateLimit {
3968    fn capacity() -> usize {
3969        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3970            .ok()
3971            .and_then(|v| v.parse().ok())
3972            .unwrap_or(120) // Picked arbitrarily
3973    }
3974
3975    fn refill_duration() -> chrono::Duration {
3976        chrono::Duration::hours(1)
3977    }
3978
3979    fn db_name() -> &'static str {
3980        "complete-with-language-model"
3981    }
3982}
3983
3984async fn complete_with_language_model(
3985    request: proto::CompleteWithLanguageModel,
3986    response: StreamingResponse<proto::CompleteWithLanguageModel>,
3987    session: Session,
3988    open_ai_api_key: Option<Arc<str>>,
3989    google_ai_api_key: Option<Arc<str>>,
3990    anthropic_api_key: Option<Arc<str>>,
3991) -> Result<()> {
3992    let Some(session) = session.for_user() else {
3993        return Err(anyhow!("user not found"))?;
3994    };
3995    authorize_access_to_language_models(&session).await?;
3996    session
3997        .rate_limiter
3998        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
3999        .await?;
4000
4001    if request.model.starts_with("gpt") {
4002        let api_key =
4003            open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
4004        complete_with_open_ai(request, response, session, api_key).await?;
4005    } else if request.model.starts_with("gemini") {
4006        let api_key = google_ai_api_key
4007            .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4008        complete_with_google_ai(request, response, session, api_key).await?;
4009    } else if request.model.starts_with("claude") {
4010        let api_key = anthropic_api_key
4011            .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
4012        complete_with_anthropic(request, response, session, api_key).await?;
4013    }
4014
4015    Ok(())
4016}
4017
4018async fn complete_with_open_ai(
4019    request: proto::CompleteWithLanguageModel,
4020    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4021    session: UserSession,
4022    api_key: Arc<str>,
4023) -> Result<()> {
4024    const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
4025
4026    let mut completion_stream = open_ai::stream_completion(
4027        &session.http_client,
4028        OPEN_AI_API_URL,
4029        &api_key,
4030        crate::ai::language_model_request_to_open_ai(request)?,
4031    )
4032    .await
4033    .context("open_ai::stream_completion request failed")?;
4034
4035    while let Some(event) = completion_stream.next().await {
4036        let event = event?;
4037        response.send(proto::LanguageModelResponse {
4038            choices: event
4039                .choices
4040                .into_iter()
4041                .map(|choice| proto::LanguageModelChoiceDelta {
4042                    index: choice.index,
4043                    delta: Some(proto::LanguageModelResponseMessage {
4044                        role: choice.delta.role.map(|role| match role {
4045                            open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4046                            open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4047                            open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4048                        } as i32),
4049                        content: choice.delta.content,
4050                    }),
4051                    finish_reason: choice.finish_reason,
4052                })
4053                .collect(),
4054        })?;
4055    }
4056
4057    Ok(())
4058}
4059
4060async fn complete_with_google_ai(
4061    request: proto::CompleteWithLanguageModel,
4062    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4063    session: UserSession,
4064    api_key: Arc<str>,
4065) -> Result<()> {
4066    let mut stream = google_ai::stream_generate_content(
4067        &session.http_client,
4068        google_ai::API_URL,
4069        api_key.as_ref(),
4070        crate::ai::language_model_request_to_google_ai(request)?,
4071    )
4072    .await
4073    .context("google_ai::stream_generate_content request failed")?;
4074
4075    while let Some(event) = stream.next().await {
4076        let event = event?;
4077        response.send(proto::LanguageModelResponse {
4078            choices: event
4079                .candidates
4080                .unwrap_or_default()
4081                .into_iter()
4082                .map(|candidate| proto::LanguageModelChoiceDelta {
4083                    index: candidate.index as u32,
4084                    delta: Some(proto::LanguageModelResponseMessage {
4085                        role: Some(match candidate.content.role {
4086                            google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4087                            google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4088                        } as i32),
4089                        content: Some(
4090                            candidate
4091                                .content
4092                                .parts
4093                                .into_iter()
4094                                .filter_map(|part| match part {
4095                                    google_ai::Part::TextPart(part) => Some(part.text),
4096                                    google_ai::Part::InlineDataPart(_) => None,
4097                                })
4098                                .collect(),
4099                        ),
4100                    }),
4101                    finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4102                })
4103                .collect(),
4104        })?;
4105    }
4106
4107    Ok(())
4108}
4109
4110async fn complete_with_anthropic(
4111    request: proto::CompleteWithLanguageModel,
4112    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4113    session: UserSession,
4114    api_key: Arc<str>,
4115) -> Result<()> {
4116    let model = anthropic::Model::from_id(&request.model)?;
4117
4118    let mut system_message = String::new();
4119    let messages = request
4120        .messages
4121        .into_iter()
4122        .filter_map(|message| match message.role() {
4123            LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
4124                role: anthropic::Role::User,
4125                content: message.content,
4126            }),
4127            LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
4128                role: anthropic::Role::Assistant,
4129                content: message.content,
4130            }),
4131            // Anthropic's API breaks system instructions out as a separate field rather
4132            // than having a system message role.
4133            LanguageModelRole::LanguageModelSystem => {
4134                if !system_message.is_empty() {
4135                    system_message.push_str("\n\n");
4136                }
4137                system_message.push_str(&message.content);
4138
4139                None
4140            }
4141        })
4142        .collect();
4143
4144    let mut stream = anthropic::stream_completion(
4145        &session.http_client,
4146        "https://api.anthropic.com",
4147        &api_key,
4148        anthropic::Request {
4149            model,
4150            messages,
4151            stream: true,
4152            system: system_message,
4153            max_tokens: 4092,
4154        },
4155    )
4156    .await?;
4157
4158    let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
4159
4160    while let Some(event) = stream.next().await {
4161        let event = event?;
4162
4163        match event {
4164            anthropic::ResponseEvent::MessageStart { message } => {
4165                if let Some(role) = message.role {
4166                    if role == "assistant" {
4167                        current_role = proto::LanguageModelRole::LanguageModelAssistant;
4168                    } else if role == "user" {
4169                        current_role = proto::LanguageModelRole::LanguageModelUser;
4170                    }
4171                }
4172            }
4173            anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
4174                match content_block {
4175                    anthropic::ContentBlock::Text { text } => {
4176                        if !text.is_empty() {
4177                            response.send(proto::LanguageModelResponse {
4178                                choices: vec![proto::LanguageModelChoiceDelta {
4179                                    index: 0,
4180                                    delta: Some(proto::LanguageModelResponseMessage {
4181                                        role: Some(current_role as i32),
4182                                        content: Some(text),
4183                                    }),
4184                                    finish_reason: None,
4185                                }],
4186                            })?;
4187                        }
4188                    }
4189                }
4190            }
4191            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
4192                anthropic::TextDelta::TextDelta { text } => {
4193                    response.send(proto::LanguageModelResponse {
4194                        choices: vec![proto::LanguageModelChoiceDelta {
4195                            index: 0,
4196                            delta: Some(proto::LanguageModelResponseMessage {
4197                                role: Some(current_role as i32),
4198                                content: Some(text),
4199                            }),
4200                            finish_reason: None,
4201                        }],
4202                    })?;
4203                }
4204            },
4205            anthropic::ResponseEvent::MessageDelta { delta, .. } => {
4206                if let Some(stop_reason) = delta.stop_reason {
4207                    response.send(proto::LanguageModelResponse {
4208                        choices: vec![proto::LanguageModelChoiceDelta {
4209                            index: 0,
4210                            delta: None,
4211                            finish_reason: Some(stop_reason),
4212                        }],
4213                    })?;
4214                }
4215            }
4216            anthropic::ResponseEvent::ContentBlockStop { .. } => {}
4217            anthropic::ResponseEvent::MessageStop {} => {}
4218            anthropic::ResponseEvent::Ping {} => {}
4219        }
4220    }
4221
4222    Ok(())
4223}
4224
4225struct CountTokensWithLanguageModelRateLimit;
4226
4227impl RateLimit for CountTokensWithLanguageModelRateLimit {
4228    fn capacity() -> usize {
4229        std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4230            .ok()
4231            .and_then(|v| v.parse().ok())
4232            .unwrap_or(600) // Picked arbitrarily
4233    }
4234
4235    fn refill_duration() -> chrono::Duration {
4236        chrono::Duration::hours(1)
4237    }
4238
4239    fn db_name() -> &'static str {
4240        "count-tokens-with-language-model"
4241    }
4242}
4243
4244async fn count_tokens_with_language_model(
4245    request: proto::CountTokensWithLanguageModel,
4246    response: Response<proto::CountTokensWithLanguageModel>,
4247    session: UserSession,
4248    google_ai_api_key: Option<Arc<str>>,
4249) -> Result<()> {
4250    authorize_access_to_language_models(&session).await?;
4251
4252    if !request.model.starts_with("gemini") {
4253        return Err(anyhow!(
4254            "counting tokens for model: {:?} is not supported",
4255            request.model
4256        ))?;
4257    }
4258
4259    session
4260        .rate_limiter
4261        .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4262        .await?;
4263
4264    let api_key = google_ai_api_key
4265        .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4266    let tokens_response = google_ai::count_tokens(
4267        &session.http_client,
4268        google_ai::API_URL,
4269        &api_key,
4270        crate::ai::count_tokens_request_to_google_ai(request)?,
4271    )
4272    .await?;
4273    response.send(proto::CountTokensResponse {
4274        token_count: tokens_response.total_tokens as u32,
4275    })?;
4276    Ok(())
4277}
4278
4279async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4280    let db = session.db().await;
4281    let flags = db.get_user_flags(session.user_id()).await?;
4282    if flags.iter().any(|flag| flag == "language-models") {
4283        Ok(())
4284    } else {
4285        Err(anyhow!("permission denied"))?
4286    }
4287}
4288
4289/// Start receiving chat updates for a channel
4290async fn join_channel_chat(
4291    request: proto::JoinChannelChat,
4292    response: Response<proto::JoinChannelChat>,
4293    session: UserSession,
4294) -> Result<()> {
4295    let channel_id = ChannelId::from_proto(request.channel_id);
4296
4297    let db = session.db().await;
4298    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4299        .await?;
4300    let messages = db
4301        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4302        .await?;
4303    response.send(proto::JoinChannelChatResponse {
4304        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4305        messages,
4306    })?;
4307    Ok(())
4308}
4309
4310/// Stop receiving chat updates for a channel
4311async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4312    let channel_id = ChannelId::from_proto(request.channel_id);
4313    session
4314        .db()
4315        .await
4316        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4317        .await?;
4318    Ok(())
4319}
4320
4321/// Retrieve the chat history for a channel
4322async fn get_channel_messages(
4323    request: proto::GetChannelMessages,
4324    response: Response<proto::GetChannelMessages>,
4325    session: UserSession,
4326) -> Result<()> {
4327    let channel_id = ChannelId::from_proto(request.channel_id);
4328    let messages = session
4329        .db()
4330        .await
4331        .get_channel_messages(
4332            channel_id,
4333            session.user_id(),
4334            MESSAGE_COUNT_PER_PAGE,
4335            Some(MessageId::from_proto(request.before_message_id)),
4336        )
4337        .await?;
4338    response.send(proto::GetChannelMessagesResponse {
4339        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4340        messages,
4341    })?;
4342    Ok(())
4343}
4344
4345/// Retrieve specific chat messages
4346async fn get_channel_messages_by_id(
4347    request: proto::GetChannelMessagesById,
4348    response: Response<proto::GetChannelMessagesById>,
4349    session: UserSession,
4350) -> Result<()> {
4351    let message_ids = request
4352        .message_ids
4353        .iter()
4354        .map(|id| MessageId::from_proto(*id))
4355        .collect::<Vec<_>>();
4356    let messages = session
4357        .db()
4358        .await
4359        .get_channel_messages_by_id(session.user_id(), &message_ids)
4360        .await?;
4361    response.send(proto::GetChannelMessagesResponse {
4362        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4363        messages,
4364    })?;
4365    Ok(())
4366}
4367
4368/// Retrieve the current users notifications
4369async fn get_notifications(
4370    request: proto::GetNotifications,
4371    response: Response<proto::GetNotifications>,
4372    session: UserSession,
4373) -> Result<()> {
4374    let notifications = session
4375        .db()
4376        .await
4377        .get_notifications(
4378            session.user_id(),
4379            NOTIFICATION_COUNT_PER_PAGE,
4380            request
4381                .before_id
4382                .map(|id| db::NotificationId::from_proto(id)),
4383        )
4384        .await?;
4385    response.send(proto::GetNotificationsResponse {
4386        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4387        notifications,
4388    })?;
4389    Ok(())
4390}
4391
4392/// Mark notifications as read
4393async fn mark_notification_as_read(
4394    request: proto::MarkNotificationRead,
4395    response: Response<proto::MarkNotificationRead>,
4396    session: UserSession,
4397) -> Result<()> {
4398    let database = &session.db().await;
4399    let notifications = database
4400        .mark_notification_as_read_by_id(
4401            session.user_id(),
4402            NotificationId::from_proto(request.notification_id),
4403        )
4404        .await?;
4405    send_notifications(
4406        &*session.connection_pool().await,
4407        &session.peer,
4408        notifications,
4409    );
4410    response.send(proto::Ack {})?;
4411    Ok(())
4412}
4413
4414/// Get the current users information
4415async fn get_private_user_info(
4416    _request: proto::GetPrivateUserInfo,
4417    response: Response<proto::GetPrivateUserInfo>,
4418    session: UserSession,
4419) -> Result<()> {
4420    let db = session.db().await;
4421
4422    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4423    let user = db
4424        .get_user_by_id(session.user_id())
4425        .await?
4426        .ok_or_else(|| anyhow!("user not found"))?;
4427    let flags = db.get_user_flags(session.user_id()).await?;
4428
4429    response.send(proto::GetPrivateUserInfoResponse {
4430        metrics_id,
4431        staff: user.admin,
4432        flags,
4433    })?;
4434    Ok(())
4435}
4436
4437fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
4438    match message {
4439        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4440        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4441        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4442        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4443        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4444            code: frame.code.into(),
4445            reason: frame.reason,
4446        })),
4447    }
4448}
4449
4450fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4451    match message {
4452        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4453        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4454        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4455        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4456        AxumMessage::Close(frame) => {
4457            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4458                code: frame.code.into(),
4459                reason: frame.reason,
4460            }))
4461        }
4462    }
4463}
4464
4465fn notify_membership_updated(
4466    connection_pool: &mut ConnectionPool,
4467    result: MembershipUpdated,
4468    user_id: UserId,
4469    peer: &Peer,
4470) {
4471    for membership in &result.new_channels.channel_memberships {
4472        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4473    }
4474    for channel_id in &result.removed_channels {
4475        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4476    }
4477
4478    let user_channels_update = proto::UpdateUserChannels {
4479        channel_memberships: result
4480            .new_channels
4481            .channel_memberships
4482            .iter()
4483            .map(|cm| proto::ChannelMembership {
4484                channel_id: cm.channel_id.to_proto(),
4485                role: cm.role.into(),
4486            })
4487            .collect(),
4488        ..Default::default()
4489    };
4490
4491    let mut update = build_channels_update(result.new_channels, vec![], connection_pool);
4492    update.delete_channels = result
4493        .removed_channels
4494        .into_iter()
4495        .map(|id| id.to_proto())
4496        .collect();
4497    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4498
4499    for connection_id in connection_pool.user_connection_ids(user_id) {
4500        peer.send(connection_id, user_channels_update.clone())
4501            .trace_err();
4502        peer.send(connection_id, update.clone()).trace_err();
4503    }
4504}
4505
4506fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4507    proto::UpdateUserChannels {
4508        channel_memberships: channels
4509            .channel_memberships
4510            .iter()
4511            .map(|m| proto::ChannelMembership {
4512                channel_id: m.channel_id.to_proto(),
4513                role: m.role.into(),
4514            })
4515            .collect(),
4516        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4517        observed_channel_message_id: channels.observed_channel_messages.clone(),
4518    }
4519}
4520
4521fn build_channels_update(
4522    channels: ChannelsForUser,
4523    channel_invites: Vec<db::Channel>,
4524    pool: &ConnectionPool,
4525) -> proto::UpdateChannels {
4526    let mut update = proto::UpdateChannels::default();
4527
4528    for channel in channels.channels {
4529        update.channels.push(channel.to_proto());
4530    }
4531
4532    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4533    update.latest_channel_message_ids = channels.latest_channel_messages;
4534
4535    for (channel_id, participants) in channels.channel_participants {
4536        update
4537            .channel_participants
4538            .push(proto::ChannelParticipants {
4539                channel_id: channel_id.to_proto(),
4540                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4541            });
4542    }
4543
4544    for channel in channel_invites {
4545        update.channel_invitations.push(channel.to_proto());
4546    }
4547
4548    update.hosted_projects = channels.hosted_projects;
4549    update.dev_servers = channels
4550        .dev_servers
4551        .into_iter()
4552        .map(|dev_server| dev_server.to_proto(pool.dev_server_status(dev_server.id)))
4553        .collect();
4554    update.remote_projects = channels.remote_projects;
4555
4556    update
4557}
4558
4559fn build_initial_contacts_update(
4560    contacts: Vec<db::Contact>,
4561    pool: &ConnectionPool,
4562) -> proto::UpdateContacts {
4563    let mut update = proto::UpdateContacts::default();
4564
4565    for contact in contacts {
4566        match contact {
4567            db::Contact::Accepted { user_id, busy } => {
4568                update.contacts.push(contact_for_user(user_id, busy, &pool));
4569            }
4570            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4571            db::Contact::Incoming { user_id } => {
4572                update
4573                    .incoming_requests
4574                    .push(proto::IncomingContactRequest {
4575                        requester_id: user_id.to_proto(),
4576                    })
4577            }
4578        }
4579    }
4580
4581    update
4582}
4583
4584fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4585    proto::Contact {
4586        user_id: user_id.to_proto(),
4587        online: pool.is_user_online(user_id),
4588        busy,
4589    }
4590}
4591
4592fn room_updated(room: &proto::Room, peer: &Peer) {
4593    broadcast(
4594        None,
4595        room.participants
4596            .iter()
4597            .filter_map(|participant| Some(participant.peer_id?.into())),
4598        |peer_id| {
4599            peer.send(
4600                peer_id,
4601                proto::RoomUpdated {
4602                    room: Some(room.clone()),
4603                },
4604            )
4605        },
4606    );
4607}
4608
4609fn channel_updated(
4610    channel: &db::channel::Model,
4611    room: &proto::Room,
4612    peer: &Peer,
4613    pool: &ConnectionPool,
4614) {
4615    let participants = room
4616        .participants
4617        .iter()
4618        .map(|p| p.user_id)
4619        .collect::<Vec<_>>();
4620
4621    broadcast(
4622        None,
4623        pool.channel_connection_ids(channel.root_id())
4624            .filter_map(|(channel_id, role)| {
4625                role.can_see_channel(channel.visibility).then(|| channel_id)
4626            }),
4627        |peer_id| {
4628            peer.send(
4629                peer_id,
4630                proto::UpdateChannels {
4631                    channel_participants: vec![proto::ChannelParticipants {
4632                        channel_id: channel.id.to_proto(),
4633                        participant_user_ids: participants.clone(),
4634                    }],
4635                    ..Default::default()
4636                },
4637            )
4638        },
4639    );
4640}
4641
4642async fn update_dev_server_status(
4643    dev_server: &dev_server::Model,
4644    status: proto::DevServerStatus,
4645    session: &Session,
4646) {
4647    let pool = session.connection_pool().await;
4648    let connections = pool.channel_connection_ids(dev_server.channel_id);
4649    for (connection_id, _) in connections {
4650        session
4651            .peer
4652            .send(
4653                connection_id,
4654                proto::UpdateChannels {
4655                    dev_servers: vec![dev_server.to_proto(status)],
4656                    ..Default::default()
4657                },
4658            )
4659            .trace_err();
4660    }
4661}
4662
4663async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4664    let db = session.db().await;
4665
4666    let contacts = db.get_contacts(user_id).await?;
4667    let busy = db.is_user_busy(user_id).await?;
4668
4669    let pool = session.connection_pool().await;
4670    let updated_contact = contact_for_user(user_id, busy, &pool);
4671    for contact in contacts {
4672        if let db::Contact::Accepted {
4673            user_id: contact_user_id,
4674            ..
4675        } = contact
4676        {
4677            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4678                session
4679                    .peer
4680                    .send(
4681                        contact_conn_id,
4682                        proto::UpdateContacts {
4683                            contacts: vec![updated_contact.clone()],
4684                            remove_contacts: Default::default(),
4685                            incoming_requests: Default::default(),
4686                            remove_incoming_requests: Default::default(),
4687                            outgoing_requests: Default::default(),
4688                            remove_outgoing_requests: Default::default(),
4689                        },
4690                    )
4691                    .trace_err();
4692            }
4693        }
4694    }
4695    Ok(())
4696}
4697
4698async fn lost_dev_server_connection(session: &Session) -> Result<()> {
4699    log::info!("lost dev server connection, unsharing projects");
4700    let project_ids = session
4701        .db()
4702        .await
4703        .get_stale_dev_server_projects(session.connection_id)
4704        .await?;
4705
4706    for project_id in project_ids {
4707        // not unshare re-checks the connection ids match, so we get away with no transaction
4708        unshare_project_internal(project_id, &session).await?;
4709    }
4710
4711    Ok(())
4712}
4713
/// Removes `connection_id` from the room it is in and notifies everyone affected.
///
/// If the database reports that the connection was not in a room, this returns
/// `Ok(())` without doing anything. Otherwise it:
/// 1. notifies the connections of each project the user left (`project_left`);
/// 2. broadcasts the updated room and, if the room belongs to a channel, the
///    channel's participant list;
/// 3. sends `CallCanceled` to every connection of users whose pending calls
///    were canceled;
/// 4. refreshes contact status for the leaving user and those callees;
/// 5. removes the user from the LiveKit room, and deletes that room when
///    `left_room.deleted` is set.
///
/// LiveKit and message-send failures are logged via `trace_err`. Database
/// errors are propagated.
///
/// NOTE(review): contacts and LiveKit use `session.user_id()`, so this assumes
/// `connection_id` belongs to the session's user — confirm at call sites.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: these are only assigned in the `Some` branch
    // below; the `else` branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out below, so the `room` broadcast to
        // participants carries a default (empty) `live_kit_room`.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection pool guard is released before the awaits below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        // Cloned because `live_kit_room` may still be needed for `delete_room`.
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
4787
4788async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4789    let left_channel_buffers = session
4790        .db()
4791        .await
4792        .leave_channel_buffers(session.connection_id)
4793        .await?;
4794
4795    for left_buffer in left_channel_buffers {
4796        channel_buffer_updated(
4797            session.connection_id,
4798            left_buffer.connections,
4799            &proto::UpdateChannelBufferCollaborators {
4800                channel_id: left_buffer.channel_id.to_proto(),
4801                collaborators: left_buffer.collaborators,
4802            },
4803            &session.peer,
4804        );
4805    }
4806
4807    Ok(())
4808}
4809
4810fn project_left(project: &db::LeftProject, session: &UserSession) {
4811    for connection_id in &project.connection_ids {
4812        if project.host_user_id == Some(session.user_id()) {
4813            session
4814                .peer
4815                .send(
4816                    *connection_id,
4817                    proto::UnshareProject {
4818                        project_id: project.id.to_proto(),
4819                    },
4820                )
4821                .trace_err();
4822        } else {
4823            session
4824                .peer
4825                .send(
4826                    *connection_id,
4827                    proto::RemoveProjectCollaborator {
4828                        project_id: project.id.to_proto(),
4829                        peer_id: Some(session.connection_id.into()),
4830                    },
4831                )
4832                .trace_err();
4833        }
4834    }
4835}
4836
4837pub trait ResultExt {
4838    type Ok;
4839
4840    fn trace_err(self) -> Option<Self::Ok>;
4841}
4842
4843impl<T, E> ResultExt for Result<T, E>
4844where
4845    E: std::fmt::Debug,
4846{
4847    type Ok = T;
4848
4849    #[track_caller]
4850    fn trace_err(self) -> Option<T> {
4851        match self {
4852            Ok(value) => Some(value),
4853            Err(error) => {
4854                tracing::error!("{:?}", error);
4855                None
4856            }
4857        }
4858    }
4859}