rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::db::billing_subscription::SubscriptionKind;
   5use crate::llm::db::LlmDatabase;
   6use crate::llm::{
   7    AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG,
   8    MIN_ACCOUNT_AGE_FOR_LLM_USE,
   9};
  10use crate::{
  11    AppState, Error, Result, auth,
  12    db::{
  13        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
  14        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  15        NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
  16        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  17    },
  18    executor::Executor,
  19};
  20use anyhow::{Context as _, anyhow, bail};
  21use async_tungstenite::tungstenite::{
  22    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  23};
  24use axum::headers::UserAgent;
  25use axum::{
  26    Extension, Router, TypedHeader,
  27    body::Body,
  28    extract::{
  29        ConnectInfo, WebSocketUpgrade,
  30        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  31    },
  32    headers::{Header, HeaderName},
  33    http::StatusCode,
  34    middleware,
  35    response::IntoResponse,
  36    routing::get,
  37};
  38use chrono::Utc;
  39use collections::{HashMap, HashSet};
  40pub use connection_pool::{ConnectionPool, ZedVersion};
  41use core::fmt::{self, Debug, Formatter};
  42use futures::TryFutureExt as _;
  43use reqwest_client::ReqwestClient;
  44use rpc::proto::{MultiLspQuery, split_repository_update};
  45use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  46use tracing::Span;
  47
  48use futures::{
  49    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  50    stream::FuturesUnordered,
  51};
  52use prometheus::{IntGauge, register_int_gauge};
  53use rpc::{
  54    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  55    proto::{
  56        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  57        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  58    },
  59};
  60use semantic_version::SemanticVersion;
  61use serde::{Serialize, Serializer};
  62use std::{
  63    any::TypeId,
  64    future::Future,
  65    marker::PhantomData,
  66    mem,
  67    net::SocketAddr,
  68    ops::{Deref, DerefMut},
  69    rc::Rc,
  70    sync::{
  71        Arc, OnceLock,
  72        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  73    },
  74    time::{Duration, Instant},
  75};
  76use time::OffsetDateTime;
  77use tokio::sync::{Semaphore, watch};
  78use tower::ServiceBuilder;
  79use tracing::{
  80    Instrument,
  81    field::{self},
  82    info_span, instrument,
  83};
  84
/// Reconnection timeout of 30 seconds.
/// NOTE(review): the code consuming this constant is not visible in this part of the file —
/// confirm the exact semantics against its callers.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean
// up old resources. `Server::start` sleeps for this long before running its cleanup task.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Paging and size limits. Their consumers are outside this part of the file; judging by the
// names they bound channel-message pages, message length and notification pages — TODO confirm.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
// Upper bound enforced by `ConnectionGuard::try_acquire`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

// Number of live `ConnectionGuard`s: incremented in `ConnectionGuard::try_acquire`, decremented
// when a guard is dropped (or immediately when the limit is exceeded).
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Names of the tracing span fields used for per-message timing. `HOST_WAITING_MS` is recorded by
// `MessageContext::forward_request`; the others are presumably all recorded by the handler
// wrapper in `Server::add_handler` — verify there.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";
 101
/// Type-erased message handler as stored in `Server::handlers` (keyed by the message's `TypeId`).
///
/// It receives the boxed envelope, the `Session` of the sending connection and the tracing
/// `Span` for the message, and returns a boxed `'static` future that yields `()`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
 104
 105pub struct ConnectionGuard;
 106
 107impl ConnectionGuard {
 108    pub fn try_acquire() -> Result<Self, ()> {
 109        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
 110        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 111            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 112            tracing::error!(
 113                "too many concurrent connections: {}",
 114                current_connections + 1
 115            );
 116            return Err(());
 117        }
 118        Ok(ConnectionGuard)
 119    }
 120}
 121
 122impl Drop for ConnectionGuard {
 123    fn drop(&mut self) {
 124        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 125    }
 126}
 127
 128struct Response<R> {
 129    peer: Arc<Peer>,
 130    receipt: Receipt<R>,
 131    responded: Arc<AtomicBool>,
 132}
 133
 134impl<R: RequestMessage> Response<R> {
 135    fn send(self, payload: R::Response) -> Result<()> {
 136        self.responded.store(true, SeqCst);
 137        self.peer.respond(self.receipt, payload)?;
 138        Ok(())
 139    }
 140}
 141
 142#[derive(Clone, Debug)]
 143pub enum Principal {
 144    User(User),
 145    Impersonated { user: User, admin: User },
 146}
 147
 148impl Principal {
 149    fn user(&self) -> &User {
 150        match self {
 151            Principal::User(user) => user,
 152            Principal::Impersonated { user, .. } => user,
 153        }
 154    }
 155
 156    fn update_span(&self, span: &tracing::Span) {
 157        match &self {
 158            Principal::User(user) => {
 159                span.record("user_id", user.id.0);
 160                span.record("login", &user.github_login);
 161            }
 162            Principal::Impersonated { user, admin } => {
 163                span.record("user_id", user.id.0);
 164                span.record("login", &user.github_login);
 165                span.record("impersonator", &admin.github_login);
 166            }
 167        }
 168    }
 169}
 170
 171#[derive(Clone)]
 172struct MessageContext {
 173    session: Session,
 174    span: tracing::Span,
 175}
 176
 177impl Deref for MessageContext {
 178    type Target = Session;
 179
 180    fn deref(&self) -> &Self::Target {
 181        &self.session
 182    }
 183}
 184
 185impl MessageContext {
 186    pub fn forward_request<T: RequestMessage>(
 187        &self,
 188        receiver_id: ConnectionId,
 189        request: T,
 190    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 191        let request_start_time = Instant::now();
 192        let span = self.span.clone();
 193        tracing::info!("start forwarding request");
 194        self.peer
 195            .forward_request(self.connection_id, receiver_id, request)
 196            .inspect(move |_| {
 197                span.record(
 198                    HOST_WAITING_MS,
 199                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 200                );
 201            })
 202            .inspect_err(|_| tracing::error!("error forwarding request"))
 203            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 204    }
 205}
 206
 207#[derive(Clone)]
 208struct Session {
 209    principal: Principal,
 210    connection_id: ConnectionId,
 211    db: Arc<tokio::sync::Mutex<DbHandle>>,
 212    peer: Arc<Peer>,
 213    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 214    app_state: Arc<AppState>,
 215    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 216    /// The GeoIP country code for the user.
 217    #[allow(unused)]
 218    geoip_country_code: Option<String>,
 219    #[allow(unused)]
 220    system_id: Option<String>,
 221    _executor: Executor,
 222}
 223
 224impl Session {
 225    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 226        #[cfg(test)]
 227        tokio::task::yield_now().await;
 228        let guard = self.db.lock().await;
 229        #[cfg(test)]
 230        tokio::task::yield_now().await;
 231        guard
 232    }
 233
 234    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 235        #[cfg(test)]
 236        tokio::task::yield_now().await;
 237        let guard = self.connection_pool.lock();
 238        ConnectionPoolGuard {
 239            guard,
 240            _not_send: PhantomData,
 241        }
 242    }
 243
 244    fn is_staff(&self) -> bool {
 245        match &self.principal {
 246            Principal::User(user) => user.admin,
 247            Principal::Impersonated { .. } => true,
 248        }
 249    }
 250
 251    fn user_id(&self) -> UserId {
 252        match &self.principal {
 253            Principal::User(user) => user.id,
 254            Principal::Impersonated { user, .. } => user.id,
 255        }
 256    }
 257
 258    pub fn email(&self) -> Option<String> {
 259        match &self.principal {
 260            Principal::User(user) => user.email_address.clone(),
 261            Principal::Impersonated { user, .. } => user.email_address.clone(),
 262        }
 263    }
 264}
 265
 266impl Debug for Session {
 267    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 268        let mut result = f.debug_struct("Session");
 269        match &self.principal {
 270            Principal::User(user) => {
 271                result.field("user", &user.github_login);
 272            }
 273            Principal::Impersonated { user, admin } => {
 274                result.field("user", &user.github_login);
 275                result.field("impersonator", &admin.github_login);
 276            }
 277        }
 278        result.field("connection_id", &self.connection_id).finish()
 279    }
 280}
 281
 282struct DbHandle(Arc<Database>);
 283
 284impl Deref for DbHandle {
 285    type Target = Database;
 286
 287    fn deref(&self) -> &Self::Target {
 288        self.0.as_ref()
 289    }
 290}
 291
/// The RPC server: owns the peer, the connection pool and the table of message handlers.
pub struct Server {
    /// Current server id. Behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    /// Peer created from the server id in `Server::new`.
    peer: Arc<Peer>,
    /// Pool of connections, shared with the background cleanup task spawned by `start`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Watch channel initialised to `false`; `teardown` sends `true` on it.
    teardown: watch::Sender<bool>,
}
 300
/// A locked view of the [`ConnectionPool`].
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying `parking_lot` mutex guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc` is `!Send`, so this marker makes the whole guard `!Send` as well.
    _not_send: PhantomData<Rc<()>>,
}
 305
/// Serializable view of a server's peer and (locked) connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized through `serialize_deref`, i.e. as the value the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 312
 313pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 314where
 315    S: Serializer,
 316    T: Deref<Target = U>,
 317    U: Serialize,
 318{
 319    Serialize::serialize(value.deref(), serializer)
 320}
 321
 322impl Server {
    /// Creates a server with the given id and registers every RPC handler on it.
    ///
    /// The peer is created from `id.0` (cast to `u32`), the connection pool and handler table
    /// start out empty, and the teardown watch channel starts at `false` with its initial
    /// receiver dropped right away. Handlers are added with `add_request_handler` /
    /// `add_message_handler`; project requests are routed through the generic
    /// `forward_read_only_project_request` / `forward_mutating_project_request` handlers, and
    /// host broadcasts through `broadcast_project_message_from_host`.
    ///
    /// Returns the fully configured server wrapped in an `Arc`.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(multi_lsp_query)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_request_handler(accept_terms_of_service)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            // NOTE(review): `GitReset` and `GitCheckoutFiles` are forwarded as read-only although
            // their names suggest they modify the repository — confirm this is intended.
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context);

        Arc::new(server)
    }
 497
    /// Spawns the detached background task that cleans up stale server resources.
    ///
    /// The task first waits for `CLEANUP_TIMEOUT`, then, for this server's id and environment:
    /// deletes stale channel chat participants, fetches the stale room and channel ids, clears
    /// stale channel-buffer collaborators and room participants (notifying affected connections),
    /// pushes contact updates, deletes LiveKit rooms that ended up empty, and finally clears old
    /// worktree entries and deletes stale servers.
    ///
    /// Every fallible step goes through `trace_err`, so a failure skips that step rather than
    /// aborting the task (`trace_err` presumably logs the error; its result is used as an
    /// `Option` here). This method itself returns `Ok(())` as soon as the task is spawned.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is created here and awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and send the refreshed
                    // collaborator list to every connection returned by the database.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These keep their empty defaults when refreshing the room fails, which
                        // turns the remaining steps for this room into no-ops.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                // The pool lock is a temporary held only for this call.
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // Only a room left without participants has its LiveKit room deleted.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Both database lookups finish before the pool lock is taken.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                // Only accepted contacts are sent the updated contact entry.
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Second pass over stale chat participants, after rooms and buffers were handled.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 668
 669    pub fn teardown(&self) {
 670        self.peer.teardown();
 671        self.connection_pool.lock().reset();
 672        let _ = self.teardown.send(true);
 673    }
 674
 675    #[cfg(test)]
 676    pub fn reset(&self, id: ServerId) {
 677        self.teardown();
 678        *self.id.lock() = id;
 679        self.peer.reset(id.0 as u32);
 680        let _ = self.teardown.send(false);
 681    }
 682
 683    #[cfg(test)]
 684    pub fn id(&self) -> ServerId {
 685        *self.id.lock()
 686    }
 687
 688    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 689    where
 690        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
 691        Fut: 'static + Send + Future<Output = Result<()>>,
 692        M: EnvelopedMessage,
 693    {
 694        let prev_handler = self.handlers.insert(
 695            TypeId::of::<M>(),
 696            Box::new(move |envelope, session, span| {
 697                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 698                let received_at = envelope.received_at;
 699                tracing::info!("message received");
 700                let start_time = Instant::now();
 701                let future = (handler)(
 702                    *envelope,
 703                    MessageContext {
 704                        session,
 705                        span: span.clone(),
 706                    },
 707                );
 708                async move {
 709                    let result = future.await;
 710                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 711                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 712                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 713                    span.record(TOTAL_DURATION_MS, total_duration_ms);
 714                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
 715                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
 716                    match result {
 717                        Err(error) => {
 718                            tracing::error!(?error, "error handling message")
 719                        }
 720                        Ok(()) => tracing::info!("finished handling message"),
 721                    }
 722                }
 723                .boxed()
 724            }),
 725        );
 726        if prev_handler.is_some() {
 727            panic!("registered a handler for the same message twice");
 728        }
 729        self
 730    }
 731
 732    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 733    where
 734        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 735        Fut: 'static + Send + Future<Output = Result<()>>,
 736        M: EnvelopedMessage,
 737    {
 738        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 739        self
 740    }
 741
 742    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 743    where
 744        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 745        Fut: Send + Future<Output = Result<()>>,
 746        M: RequestMessage,
 747    {
 748        let handler = Arc::new(handler);
 749        self.add_handler(move |envelope, session| {
 750            let receipt = envelope.receipt();
 751            let handler = handler.clone();
 752            async move {
 753                let peer = session.peer.clone();
 754                let responded = Arc::new(AtomicBool::default());
 755                let response = Response {
 756                    peer: peer.clone(),
 757                    responded: responded.clone(),
 758                    receipt,
 759                };
 760                match (handler)(envelope.payload, response, session).await {
 761                    Ok(()) => {
 762                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 763                            Ok(())
 764                        } else {
 765                            Err(anyhow!("handler did not send a response"))?
 766                        }
 767                    }
 768                    Err(error) => {
 769                        let proto_err = match &error {
 770                            Error::Internal(err) => err.to_proto(),
 771                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 772                        };
 773                        peer.respond_with_error(receipt, proto_err)?;
 774                        Err(error)
 775                    }
 776                }
 777            }
 778        })
 779    }
 780
    /// Builds the future that drives a single client connection from start to finish.
    ///
    /// Synchronously, this creates a `handle connection` span, fills in the
    /// fields that are already known, and subscribes to the teardown channel.
    /// The returned future (instrumented with that span) then:
    ///
    /// 1. returns immediately if the server is already tearing down;
    /// 2. registers the connection with the peer and builds a `Session`;
    /// 3. sends the initial client update, returning if that fails;
    /// 4. drops `connection_guard` and enters the message loop;
    /// 5. once the loop ends because I/O finished or the incoming stream closed,
    ///    drops unfinished foreground handlers and awaits `connection_lost`.
    ///
    /// If the teardown channel changes while the loop is running, the future
    /// returns right away without calling `connection_lost`.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // Note: this shadows the client's `user_agent` argument with the
            // server's own user agent, used for outgoing HTTP requests.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only present when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The guard is held only until the initial update has been sent.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            // Number of permits currently taken, i.e. handlers in flight (plus the
            // permit held while waiting for the next message).
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // A permit is acquired before the next message is read, so no more than
                // `MAX_CONCURRENT_HANDLERS` messages are taken off the stream at once.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls its branches in the order they are written.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    // Drives the in-flight foreground handlers; their output is discarded.
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                multi_lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // The permit is released only after the handler has finished.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages are spawned on the executor; foreground
                                // messages are driven by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set drops any foreground handlers that have not finished.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
 972
 973    async fn send_initial_client_update(
 974        &self,
 975        connection_id: ConnectionId,
 976        zed_version: ZedVersion,
 977        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 978        session: &Session,
 979    ) -> Result<()> {
 980        self.peer.send(
 981            connection_id,
 982            proto::Hello {
 983                peer_id: Some(connection_id.into()),
 984            },
 985        )?;
 986        tracing::info!("sent hello message");
 987        if let Some(send_connection_id) = send_connection_id.take() {
 988            let _ = send_connection_id.send(connection_id);
 989        }
 990
 991        match &session.principal {
 992            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 993                if !user.connected_once {
 994                    self.peer.send(connection_id, proto::ShowContacts {})?;
 995                    self.app_state
 996                        .db
 997                        .set_user_connected_once(user.id, true)
 998                        .await?;
 999                }
1000
1001                update_user_plan(session).await?;
1002
1003                let contacts = self.app_state.db.get_contacts(user.id).await?;
1004
1005                {
1006                    let mut pool = self.connection_pool.lock();
1007                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1008                    self.peer.send(
1009                        connection_id,
1010                        build_initial_contacts_update(contacts, &pool),
1011                    )?;
1012                }
1013
1014                if should_auto_subscribe_to_channels(zed_version) {
1015                    subscribe_user_to_channels(user.id, session).await?;
1016                }
1017
1018                if let Some(incoming_call) =
1019                    self.app_state.db.incoming_call_for_user(user.id).await?
1020                {
1021                    self.peer.send(connection_id, incoming_call)?;
1022                }
1023
1024                update_user_contacts(user.id, session).await?;
1025            }
1026        }
1027
1028        Ok(())
1029    }
1030
1031    pub async fn invite_code_redeemed(
1032        self: &Arc<Self>,
1033        inviter_id: UserId,
1034        invitee_id: UserId,
1035    ) -> Result<()> {
1036        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1037            if let Some(code) = &user.invite_code {
1038                let pool = self.connection_pool.lock();
1039                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1040                for connection_id in pool.user_connection_ids(inviter_id) {
1041                    self.peer.send(
1042                        connection_id,
1043                        proto::UpdateContacts {
1044                            contacts: vec![invitee_contact.clone()],
1045                            ..Default::default()
1046                        },
1047                    )?;
1048                    self.peer.send(
1049                        connection_id,
1050                        proto::UpdateInviteInfo {
1051                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1052                            count: user.invite_count as u32,
1053                        },
1054                    )?;
1055                }
1056            }
1057        }
1058        Ok(())
1059    }
1060
1061    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1062        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1063            if let Some(invite_code) = &user.invite_code {
1064                let pool = self.connection_pool.lock();
1065                for connection_id in pool.user_connection_ids(user_id) {
1066                    self.peer.send(
1067                        connection_id,
1068                        proto::UpdateInviteInfo {
1069                            url: format!(
1070                                "{}{}",
1071                                self.app_state.config.invite_link_prefix, invite_code
1072                            ),
1073                            count: user.invite_count as u32,
1074                        },
1075                    )?;
1076                }
1077            }
1078        }
1079        Ok(())
1080    }
1081
1082    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1083        ServerSnapshot {
1084            connection_pool: ConnectionPoolGuard {
1085                guard: self.connection_pool.lock(),
1086                _not_send: PhantomData,
1087            },
1088            peer: &self.peer,
1089        }
1090    }
1091}
1092
1093impl Deref for ConnectionPoolGuard<'_> {
1094    type Target = ConnectionPool;
1095
1096    fn deref(&self) -> &Self::Target {
1097        &self.guard
1098    }
1099}
1100
1101impl DerefMut for ConnectionPoolGuard<'_> {
1102    fn deref_mut(&mut self) -> &mut Self::Target {
1103        &mut self.guard
1104    }
1105}
1106
impl Drop for ConnectionPoolGuard<'_> {
    /// In test builds, calls `check_invariants` each time a guard is dropped.
    /// In all other builds the body is empty.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1113
1114fn broadcast<F>(
1115    sender_id: Option<ConnectionId>,
1116    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1117    mut f: F,
1118) where
1119    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1120{
1121    for receiver_id in receiver_ids {
1122        if Some(receiver_id) != sender_id {
1123            if let Err(error) = f(receiver_id) {
1124                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1125            }
1126        }
1127    }
1128}
1129
1130pub struct ProtocolVersion(u32);
1131
1132impl Header for ProtocolVersion {
1133    fn name() -> &'static HeaderName {
1134        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1135        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1136    }
1137
1138    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1139    where
1140        Self: Sized,
1141        I: Iterator<Item = &'i axum::http::HeaderValue>,
1142    {
1143        let version = values
1144            .next()
1145            .ok_or_else(axum::headers::Error::invalid)?
1146            .to_str()
1147            .map_err(|_| axum::headers::Error::invalid())?
1148            .parse()
1149            .map_err(|_| axum::headers::Error::invalid())?;
1150        Ok(Self(version))
1151    }
1152
1153    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1154        values.extend([self.0.to_string().parse().unwrap()]);
1155    }
1156}
1157
1158pub struct AppVersionHeader(SemanticVersion);
1159impl Header for AppVersionHeader {
1160    fn name() -> &'static HeaderName {
1161        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1162        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1163    }
1164
1165    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1166    where
1167        Self: Sized,
1168        I: Iterator<Item = &'i axum::http::HeaderValue>,
1169    {
1170        let version = values
1171            .next()
1172            .ok_or_else(axum::headers::Error::invalid)?
1173            .to_str()
1174            .map_err(|_| axum::headers::Error::invalid())?
1175            .parse()
1176            .map_err(|_| axum::headers::Error::invalid())?;
1177        Ok(Self(version))
1178    }
1179
1180    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1181        values.extend([self.0.to_string().parse().unwrap()]);
1182    }
1183}
1184
1185#[derive(Debug)]
1186pub struct ReleaseChannelHeader(String);
1187
1188impl Header for ReleaseChannelHeader {
1189    fn name() -> &'static HeaderName {
1190        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1191        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1192    }
1193
1194    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1195    where
1196        Self: Sized,
1197        I: Iterator<Item = &'i axum::http::HeaderValue>,
1198    {
1199        Ok(Self(
1200            values
1201                .next()
1202                .ok_or_else(axum::headers::Error::invalid)?
1203                .to_str()
1204                .map_err(|_| axum::headers::Error::invalid())?
1205                .to_owned(),
1206        ))
1207    }
1208
1209    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1210        values.extend([self.0.parse().unwrap()]);
1211    }
1212}
1213
1214pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1215    Router::new()
1216        .route("/rpc", get(handle_websocket_request))
1217        .layer(
1218            ServiceBuilder::new()
1219                .layer(Extension(server.app_state.clone()))
1220                .layer(middleware::from_fn(auth::validate_header)),
1221        )
1222        .route("/metrics", get(handle_metrics))
1223        .layer(Extension(server))
1224}
1225
1226pub async fn handle_websocket_request(
1227    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1228    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1229    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1230    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1231    Extension(server): Extension<Arc<Server>>,
1232    Extension(principal): Extension<Principal>,
1233    user_agent: Option<TypedHeader<UserAgent>>,
1234    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1235    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1236    ws: WebSocketUpgrade,
1237) -> axum::response::Response {
1238    if protocol_version != rpc::PROTOCOL_VERSION {
1239        return (
1240            StatusCode::UPGRADE_REQUIRED,
1241            "client must be upgraded".to_string(),
1242        )
1243            .into_response();
1244    }
1245
1246    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1247        return (
1248            StatusCode::UPGRADE_REQUIRED,
1249            "no version header found".to_string(),
1250        )
1251            .into_response();
1252    };
1253
1254    let release_channel = release_channel_header.map(|header| header.0.0);
1255
1256    if !version.can_collaborate() {
1257        return (
1258            StatusCode::UPGRADE_REQUIRED,
1259            "client must be upgraded".to_string(),
1260        )
1261            .into_response();
1262    }
1263
1264    let socket_address = socket_address.to_string();
1265
1266    // Acquire connection guard before WebSocket upgrade
1267    let connection_guard = match ConnectionGuard::try_acquire() {
1268        Ok(guard) => guard,
1269        Err(()) => {
1270            return (
1271                StatusCode::SERVICE_UNAVAILABLE,
1272                "Too many concurrent connections",
1273            )
1274                .into_response();
1275        }
1276    };
1277
1278    ws.on_upgrade(move |socket| {
1279        let socket = socket
1280            .map_ok(to_tungstenite_message)
1281            .err_into()
1282            .with(|message| async move { to_axum_message(message) });
1283        let connection = Connection::new(Box::pin(socket));
1284        async move {
1285            server
1286                .handle_connection(
1287                    connection,
1288                    socket_address,
1289                    principal,
1290                    version,
1291                    release_channel,
1292                    user_agent.map(|header| header.to_string()),
1293                    country_code_header.map(|header| header.to_string()),
1294                    system_id_header.map(|header| header.to_string()),
1295                    None,
1296                    Executor::Production,
1297                    Some(connection_guard),
1298                )
1299                .await;
1300        }
1301    })
1302}
1303
1304pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1305    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1306    let connections_metric = CONNECTIONS_METRIC
1307        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1308
1309    let connections = server
1310        .connection_pool
1311        .lock()
1312        .connections()
1313        .filter(|connection| !connection.admin)
1314        .count();
1315    connections_metric.set(connections as _);
1316
1317    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1318    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1319        register_int_gauge!(
1320            "shared_projects",
1321            "number of open projects with one or more guests"
1322        )
1323        .unwrap()
1324    });
1325
1326    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1327    shared_projects_metric.set(shared_projects as _);
1328
1329    let encoder = prometheus::TextEncoder::new();
1330    let metric_families = prometheus::gather();
1331    let encoded_metrics = encoder
1332        .encode_to_string(&metric_families)
1333        .map_err(|err| anyhow!("{err}"))?;
1334    Ok(encoded_metrics)
1335}
1336
/// Cleans up after a connection has ended.
///
/// Right away, the peer is disconnected, the connection is removed from the
/// connection pool (an error here is returned to the caller), and the database
/// is told the connection was lost (an error here is only traced).
///
/// The function then waits for whichever comes first: `RECONNECT_TIMEOUT`
/// elapsing, or the teardown channel changing. On timeout, the session leaves
/// its room and channel buffers, any call for the user is declined if
/// `is_user_online` reports the user as offline, and the user's contacts are
/// updated. If teardown comes first, none of that cleanup runs.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls the timeout branch before the teardown branch.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline the call when the pool reports the user as offline.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1383
1384/// Acknowledges a ping from a client, used to keep the connection alive.
1385async fn ping(
1386    _: proto::Ping,
1387    response: Response<proto::Ping>,
1388    _session: MessageContext,
1389) -> Result<()> {
1390    response.send(proto::Ack {})?;
1391    Ok(())
1392}
1393
1394/// Creates a new room for calling (outside of channels)
1395async fn create_room(
1396    _request: proto::CreateRoom,
1397    response: Response<proto::CreateRoom>,
1398    session: MessageContext,
1399) -> Result<()> {
1400    let livekit_room = nanoid::nanoid!(30);
1401
1402    let live_kit_connection_info = util::maybe!(async {
1403        let live_kit = session.app_state.livekit_client.as_ref();
1404        let live_kit = live_kit?;
1405        let user_id = session.user_id().to_string();
1406
1407        let token = live_kit
1408            .room_token(&livekit_room, &user_id.to_string())
1409            .trace_err()?;
1410
1411        Some(proto::LiveKitConnectionInfo {
1412            server_url: live_kit.url().into(),
1413            token,
1414            can_publish: true,
1415        })
1416    })
1417    .await;
1418
1419    let room = session
1420        .db()
1421        .await
1422        .create_room(session.user_id(), session.connection_id, &livekit_room)
1423        .await?;
1424
1425    response.send(proto::CreateRoomResponse {
1426        room: Some(room.clone()),
1427        live_kit_connection_info,
1428    })?;
1429
1430    update_user_contacts(session.user_id(), &session).await?;
1431    Ok(())
1432}
1433
1434/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1435async fn join_room(
1436    request: proto::JoinRoom,
1437    response: Response<proto::JoinRoom>,
1438    session: MessageContext,
1439) -> Result<()> {
1440    let room_id = RoomId::from_proto(request.id);
1441
1442    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1443
1444    if let Some(channel_id) = channel_id {
1445        return join_channel_internal(channel_id, Box::new(response), session).await;
1446    }
1447
1448    let joined_room = {
1449        let room = session
1450            .db()
1451            .await
1452            .join_room(room_id, session.user_id(), session.connection_id)
1453            .await?;
1454        room_updated(&room.room, &session.peer);
1455        room.into_inner()
1456    };
1457
1458    for connection_id in session
1459        .connection_pool()
1460        .await
1461        .user_connection_ids(session.user_id())
1462    {
1463        session
1464            .peer
1465            .send(
1466                connection_id,
1467                proto::CallCanceled {
1468                    room_id: room_id.to_proto(),
1469                },
1470            )
1471            .trace_err();
1472    }
1473
1474    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1475    {
1476        live_kit
1477            .room_token(
1478                &joined_room.room.livekit_room,
1479                &session.user_id().to_string(),
1480            )
1481            .trace_err()
1482            .map(|token| proto::LiveKitConnectionInfo {
1483                server_url: live_kit.url().into(),
1484                token,
1485                can_publish: true,
1486            })
1487    } else {
1488        None
1489    };
1490
1491    response.send(proto::JoinRoomResponse {
1492        room: Some(joined_room.room),
1493        channel_id: None,
1494        live_kit_connection_info,
1495    })?;
1496
1497    update_user_contacts(session.user_id(), &session).await?;
1498    Ok(())
1499}
1500
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Steps, in order:
/// 1. Ask the database to rejoin the room for this user and connection.
/// 2. Reply with the room plus the reshared and rejoined projects.
/// 3. Publish the room update, then notify collaborators of every reshared
///    project about the new peer id and broadcast the project's worktrees.
/// 4. Call `notify_rejoined_projects` for the rejoined projects.
/// 5. If the rejoined room has a channel, publish a channel update.
/// 6. Refresh the contacts of the current user.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    // Assigned inside the block below so they stay usable after the value
    // returned by `rejoin_room` goes out of scope.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // The response is sent before any notification goes out to other peers.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that the old connection id was replaced
            // by the current one; send failures are only traced.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktree list to its collaborators on
            // behalf of this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1592
1593fn notify_rejoined_projects(
1594    rejoined_projects: &mut Vec<RejoinedProject>,
1595    session: &Session,
1596) -> Result<()> {
1597    for project in rejoined_projects.iter() {
1598        for collaborator in &project.collaborators {
1599            session
1600                .peer
1601                .send(
1602                    collaborator.connection_id,
1603                    proto::UpdateProjectCollaborator {
1604                        project_id: project.id.to_proto(),
1605                        old_peer_id: Some(project.old_connection_id.into()),
1606                        new_peer_id: Some(session.connection_id.into()),
1607                    },
1608                )
1609                .trace_err();
1610        }
1611    }
1612
1613    for project in rejoined_projects {
1614        for worktree in mem::take(&mut project.worktrees) {
1615            // Stream this worktree's entries.
1616            let message = proto::UpdateWorktree {
1617                project_id: project.id.to_proto(),
1618                worktree_id: worktree.id,
1619                abs_path: worktree.abs_path.clone(),
1620                root_name: worktree.root_name,
1621                updated_entries: worktree.updated_entries,
1622                removed_entries: worktree.removed_entries,
1623                scan_id: worktree.scan_id,
1624                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1625                updated_repositories: worktree.updated_repositories,
1626                removed_repositories: worktree.removed_repositories,
1627            };
1628            for update in proto::split_worktree_update(message) {
1629                session.peer.send(session.connection_id, update)?;
1630            }
1631
1632            // Stream this worktree's diagnostics.
1633            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1634            if let Some(summary) = worktree_diagnostics.next() {
1635                let message = proto::UpdateDiagnosticSummary {
1636                    project_id: project.id.to_proto(),
1637                    worktree_id: worktree.id,
1638                    summary: Some(summary),
1639                    more_summaries: worktree_diagnostics.collect(),
1640                };
1641                session.peer.send(session.connection_id, message)?;
1642            }
1643
1644            for settings_file in worktree.settings_files {
1645                session.peer.send(
1646                    session.connection_id,
1647                    proto::UpdateWorktreeSettings {
1648                        project_id: project.id.to_proto(),
1649                        worktree_id: worktree.id,
1650                        path: settings_file.path,
1651                        content: Some(settings_file.content),
1652                        kind: Some(settings_file.kind.to_proto().into()),
1653                    },
1654                )?;
1655            }
1656        }
1657
1658        for repository in mem::take(&mut project.updated_repositories) {
1659            for update in split_repository_update(repository) {
1660                session.peer.send(session.connection_id, update)?;
1661            }
1662        }
1663
1664        for id in mem::take(&mut project.removed_repositories) {
1665            session.peer.send(
1666                session.connection_id,
1667                proto::RemoveRepository {
1668                    project_id: project.id.to_proto(),
1669                    id,
1670                },
1671            )?;
1672        }
1673    }
1674
1675    Ok(())
1676}
1677
1678/// leave room disconnects from the room.
1679async fn leave_room(
1680    _: proto::LeaveRoom,
1681    response: Response<proto::LeaveRoom>,
1682    session: MessageContext,
1683) -> Result<()> {
1684    leave_room_for_session(&session, session.connection_id).await?;
1685    response.send(proto::Ack {})?;
1686    Ok(())
1687}
1688
1689/// Updates the permissions of someone else in the room.
1690async fn set_room_participant_role(
1691    request: proto::SetRoomParticipantRole,
1692    response: Response<proto::SetRoomParticipantRole>,
1693    session: MessageContext,
1694) -> Result<()> {
1695    let user_id = UserId::from_proto(request.user_id);
1696    let role = ChannelRole::from(request.role());
1697
1698    let (livekit_room, can_publish) = {
1699        let room = session
1700            .db()
1701            .await
1702            .set_room_participant_role(
1703                session.user_id(),
1704                RoomId::from_proto(request.room_id),
1705                user_id,
1706                role,
1707            )
1708            .await?;
1709
1710        let livekit_room = room.livekit_room.clone();
1711        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1712        room_updated(&room, &session.peer);
1713        (livekit_room, can_publish)
1714    };
1715
1716    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1717        live_kit
1718            .update_participant(
1719                livekit_room.clone(),
1720                request.user_id.to_string(),
1721                livekit_api::proto::ParticipantPermission {
1722                    can_subscribe: true,
1723                    can_publish,
1724                    can_publish_data: can_publish,
1725                    hidden: false,
1726                    recorder: false,
1727                },
1728            )
1729            .await
1730            .trace_err();
1731    }
1732
1733    response.send(proto::Ack {})?;
1734    Ok(())
1735}
1736
1737/// Call someone else into the current room
1738async fn call(
1739    request: proto::Call,
1740    response: Response<proto::Call>,
1741    session: MessageContext,
1742) -> Result<()> {
1743    let room_id = RoomId::from_proto(request.room_id);
1744    let calling_user_id = session.user_id();
1745    let calling_connection_id = session.connection_id;
1746    let called_user_id = UserId::from_proto(request.called_user_id);
1747    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1748    if !session
1749        .db()
1750        .await
1751        .has_contact(calling_user_id, called_user_id)
1752        .await?
1753    {
1754        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1755    }
1756
1757    let incoming_call = {
1758        let (room, incoming_call) = &mut *session
1759            .db()
1760            .await
1761            .call(
1762                room_id,
1763                calling_user_id,
1764                calling_connection_id,
1765                called_user_id,
1766                initial_project_id,
1767            )
1768            .await?;
1769        room_updated(room, &session.peer);
1770        mem::take(incoming_call)
1771    };
1772    update_user_contacts(called_user_id, &session).await?;
1773
1774    let mut calls = session
1775        .connection_pool()
1776        .await
1777        .user_connection_ids(called_user_id)
1778        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1779        .collect::<FuturesUnordered<_>>();
1780
1781    while let Some(call_response) = calls.next().await {
1782        match call_response.as_ref() {
1783            Ok(_) => {
1784                response.send(proto::Ack {})?;
1785                return Ok(());
1786            }
1787            Err(_) => {
1788                call_response.trace_err();
1789            }
1790        }
1791    }
1792
1793    {
1794        let room = session
1795            .db()
1796            .await
1797            .call_failed(room_id, called_user_id)
1798            .await?;
1799        room_updated(&room, &session.peer);
1800    }
1801    update_user_contacts(called_user_id, &session).await?;
1802
1803    Err(anyhow!("failed to ring user"))?
1804}
1805
1806/// Cancel an outgoing call.
1807async fn cancel_call(
1808    request: proto::CancelCall,
1809    response: Response<proto::CancelCall>,
1810    session: MessageContext,
1811) -> Result<()> {
1812    let called_user_id = UserId::from_proto(request.called_user_id);
1813    let room_id = RoomId::from_proto(request.room_id);
1814    {
1815        let room = session
1816            .db()
1817            .await
1818            .cancel_call(room_id, session.connection_id, called_user_id)
1819            .await?;
1820        room_updated(&room, &session.peer);
1821    }
1822
1823    for connection_id in session
1824        .connection_pool()
1825        .await
1826        .user_connection_ids(called_user_id)
1827    {
1828        session
1829            .peer
1830            .send(
1831                connection_id,
1832                proto::CallCanceled {
1833                    room_id: room_id.to_proto(),
1834                },
1835            )
1836            .trace_err();
1837    }
1838    response.send(proto::Ack {})?;
1839
1840    update_user_contacts(called_user_id, &session).await?;
1841    Ok(())
1842}
1843
1844/// Decline an incoming call.
1845async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1846    let room_id = RoomId::from_proto(message.room_id);
1847    {
1848        let room = session
1849            .db()
1850            .await
1851            .decline_call(Some(room_id), session.user_id())
1852            .await?
1853            .context("declining call")?;
1854        room_updated(&room, &session.peer);
1855    }
1856
1857    for connection_id in session
1858        .connection_pool()
1859        .await
1860        .user_connection_ids(session.user_id())
1861    {
1862        session
1863            .peer
1864            .send(
1865                connection_id,
1866                proto::CallCanceled {
1867                    room_id: room_id.to_proto(),
1868                },
1869            )
1870            .trace_err();
1871    }
1872    update_user_contacts(session.user_id(), &session).await?;
1873    Ok(())
1874}
1875
1876/// Updates other participants in the room with your current location.
1877async fn update_participant_location(
1878    request: proto::UpdateParticipantLocation,
1879    response: Response<proto::UpdateParticipantLocation>,
1880    session: MessageContext,
1881) -> Result<()> {
1882    let room_id = RoomId::from_proto(request.room_id);
1883    let location = request.location.context("invalid location")?;
1884
1885    let db = session.db().await;
1886    let room = db
1887        .update_room_participant_location(room_id, session.connection_id, location)
1888        .await?;
1889
1890    room_updated(&room, &session.peer);
1891    response.send(proto::Ack {})?;
1892    Ok(())
1893}
1894
1895/// Share a project into the room.
1896async fn share_project(
1897    request: proto::ShareProject,
1898    response: Response<proto::ShareProject>,
1899    session: MessageContext,
1900) -> Result<()> {
1901    let (project_id, room) = &*session
1902        .db()
1903        .await
1904        .share_project(
1905            RoomId::from_proto(request.room_id),
1906            session.connection_id,
1907            &request.worktrees,
1908            request.is_ssh_project,
1909        )
1910        .await?;
1911    response.send(proto::ShareProjectResponse {
1912        project_id: project_id.to_proto(),
1913    })?;
1914    room_updated(room, &session.peer);
1915
1916    Ok(())
1917}
1918
1919/// Unshare a project from the room.
1920async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1921    let project_id = ProjectId::from_proto(message.project_id);
1922    unshare_project_internal(project_id, session.connection_id, &session).await
1923}
1924
1925async fn unshare_project_internal(
1926    project_id: ProjectId,
1927    connection_id: ConnectionId,
1928    session: &Session,
1929) -> Result<()> {
1930    let delete = {
1931        let room_guard = session
1932            .db()
1933            .await
1934            .unshare_project(project_id, connection_id)
1935            .await?;
1936
1937        let (delete, room, guest_connection_ids) = &*room_guard;
1938
1939        let message = proto::UnshareProject {
1940            project_id: project_id.to_proto(),
1941        };
1942
1943        broadcast(
1944            Some(connection_id),
1945            guest_connection_ids.iter().copied(),
1946            |conn_id| session.peer.send(conn_id, message.clone()),
1947        );
1948        if let Some(room) = room {
1949            room_updated(room, &session.peer);
1950        }
1951
1952        *delete
1953    };
1954
1955    if delete {
1956        let db = session.db().await;
1957        db.delete_project(project_id).await?;
1958    }
1959
1960    Ok(())
1961}
1962
1963/// Join someone elses shared project.
1964async fn join_project(
1965    request: proto::JoinProject,
1966    response: Response<proto::JoinProject>,
1967    session: MessageContext,
1968) -> Result<()> {
1969    let project_id = ProjectId::from_proto(request.project_id);
1970
1971    tracing::info!(%project_id, "join project");
1972
1973    let db = session.db().await;
1974    let (project, replica_id) = &mut *db
1975        .join_project(
1976            project_id,
1977            session.connection_id,
1978            session.user_id(),
1979            request.committer_name.clone(),
1980            request.committer_email.clone(),
1981        )
1982        .await?;
1983    drop(db);
1984    tracing::info!(%project_id, "join remote project");
1985    let collaborators = project
1986        .collaborators
1987        .iter()
1988        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1989        .map(|collaborator| collaborator.to_proto())
1990        .collect::<Vec<_>>();
1991    let project_id = project.id;
1992    let guest_user_id = session.user_id();
1993
1994    let worktrees = project
1995        .worktrees
1996        .iter()
1997        .map(|(id, worktree)| proto::WorktreeMetadata {
1998            id: *id,
1999            root_name: worktree.root_name.clone(),
2000            visible: worktree.visible,
2001            abs_path: worktree.abs_path.clone(),
2002        })
2003        .collect::<Vec<_>>();
2004
2005    let add_project_collaborator = proto::AddProjectCollaborator {
2006        project_id: project_id.to_proto(),
2007        collaborator: Some(proto::Collaborator {
2008            peer_id: Some(session.connection_id.into()),
2009            replica_id: replica_id.0 as u32,
2010            user_id: guest_user_id.to_proto(),
2011            is_host: false,
2012            committer_name: request.committer_name.clone(),
2013            committer_email: request.committer_email.clone(),
2014        }),
2015    };
2016
2017    for collaborator in &collaborators {
2018        session
2019            .peer
2020            .send(
2021                collaborator.peer_id.unwrap().into(),
2022                add_project_collaborator.clone(),
2023            )
2024            .trace_err();
2025    }
2026
2027    // First, we send the metadata associated with each worktree.
2028    let (language_servers, language_server_capabilities) = project
2029        .language_servers
2030        .clone()
2031        .into_iter()
2032        .map(|server| (server.server, server.capabilities))
2033        .unzip();
2034    response.send(proto::JoinProjectResponse {
2035        project_id: project.id.0 as u64,
2036        worktrees: worktrees.clone(),
2037        replica_id: replica_id.0 as u32,
2038        collaborators: collaborators.clone(),
2039        language_servers,
2040        language_server_capabilities,
2041        role: project.role.into(),
2042    })?;
2043
2044    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2045        // Stream this worktree's entries.
2046        let message = proto::UpdateWorktree {
2047            project_id: project_id.to_proto(),
2048            worktree_id,
2049            abs_path: worktree.abs_path.clone(),
2050            root_name: worktree.root_name,
2051            updated_entries: worktree.entries,
2052            removed_entries: Default::default(),
2053            scan_id: worktree.scan_id,
2054            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2055            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2056            removed_repositories: Default::default(),
2057        };
2058        for update in proto::split_worktree_update(message) {
2059            session.peer.send(session.connection_id, update.clone())?;
2060        }
2061
2062        // Stream this worktree's diagnostics.
2063        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
2064        if let Some(summary) = worktree_diagnostics.next() {
2065            let message = proto::UpdateDiagnosticSummary {
2066                project_id: project.id.to_proto(),
2067                worktree_id: worktree.id,
2068                summary: Some(summary),
2069                more_summaries: worktree_diagnostics.collect(),
2070            };
2071            session.peer.send(session.connection_id, message)?;
2072        }
2073
2074        for settings_file in worktree.settings_files {
2075            session.peer.send(
2076                session.connection_id,
2077                proto::UpdateWorktreeSettings {
2078                    project_id: project_id.to_proto(),
2079                    worktree_id: worktree.id,
2080                    path: settings_file.path,
2081                    content: Some(settings_file.content),
2082                    kind: Some(settings_file.kind.to_proto() as i32),
2083                },
2084            )?;
2085        }
2086    }
2087
2088    for repository in mem::take(&mut project.repositories) {
2089        for update in split_repository_update(repository) {
2090            session.peer.send(session.connection_id, update)?;
2091        }
2092    }
2093
2094    for language_server in &project.language_servers {
2095        session.peer.send(
2096            session.connection_id,
2097            proto::UpdateLanguageServer {
2098                project_id: project_id.to_proto(),
2099                server_name: Some(language_server.server.name.clone()),
2100                language_server_id: language_server.server.id,
2101                variant: Some(
2102                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2103                        proto::LspDiskBasedDiagnosticsUpdated {},
2104                    ),
2105                ),
2106            },
2107        )?;
2108    }
2109
2110    Ok(())
2111}
2112
2113/// Leave someone elses shared project.
2114async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2115    let sender_id = session.connection_id;
2116    let project_id = ProjectId::from_proto(request.project_id);
2117    let db = session.db().await;
2118
2119    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2120    tracing::info!(
2121        %project_id,
2122        "leave project"
2123    );
2124
2125    project_left(project, &session);
2126    if let Some(room) = room {
2127        room_updated(room, &session.peer);
2128    }
2129
2130    Ok(())
2131}
2132
2133/// Updates other participants with changes to the project
2134async fn update_project(
2135    request: proto::UpdateProject,
2136    response: Response<proto::UpdateProject>,
2137    session: MessageContext,
2138) -> Result<()> {
2139    let project_id = ProjectId::from_proto(request.project_id);
2140    let (room, guest_connection_ids) = &*session
2141        .db()
2142        .await
2143        .update_project(project_id, session.connection_id, &request.worktrees)
2144        .await?;
2145    broadcast(
2146        Some(session.connection_id),
2147        guest_connection_ids.iter().copied(),
2148        |connection_id| {
2149            session
2150                .peer
2151                .forward_send(session.connection_id, connection_id, request.clone())
2152        },
2153    );
2154    if let Some(room) = room {
2155        room_updated(room, &session.peer);
2156    }
2157    response.send(proto::Ack {})?;
2158
2159    Ok(())
2160}
2161
2162/// Updates other participants with changes to the worktree
2163async fn update_worktree(
2164    request: proto::UpdateWorktree,
2165    response: Response<proto::UpdateWorktree>,
2166    session: MessageContext,
2167) -> Result<()> {
2168    let guest_connection_ids = session
2169        .db()
2170        .await
2171        .update_worktree(&request, session.connection_id)
2172        .await?;
2173
2174    broadcast(
2175        Some(session.connection_id),
2176        guest_connection_ids.iter().copied(),
2177        |connection_id| {
2178            session
2179                .peer
2180                .forward_send(session.connection_id, connection_id, request.clone())
2181        },
2182    );
2183    response.send(proto::Ack {})?;
2184    Ok(())
2185}
2186
2187async fn update_repository(
2188    request: proto::UpdateRepository,
2189    response: Response<proto::UpdateRepository>,
2190    session: MessageContext,
2191) -> Result<()> {
2192    let guest_connection_ids = session
2193        .db()
2194        .await
2195        .update_repository(&request, session.connection_id)
2196        .await?;
2197
2198    broadcast(
2199        Some(session.connection_id),
2200        guest_connection_ids.iter().copied(),
2201        |connection_id| {
2202            session
2203                .peer
2204                .forward_send(session.connection_id, connection_id, request.clone())
2205        },
2206    );
2207    response.send(proto::Ack {})?;
2208    Ok(())
2209}
2210
2211async fn remove_repository(
2212    request: proto::RemoveRepository,
2213    response: Response<proto::RemoveRepository>,
2214    session: MessageContext,
2215) -> Result<()> {
2216    let guest_connection_ids = session
2217        .db()
2218        .await
2219        .remove_repository(&request, session.connection_id)
2220        .await?;
2221
2222    broadcast(
2223        Some(session.connection_id),
2224        guest_connection_ids.iter().copied(),
2225        |connection_id| {
2226            session
2227                .peer
2228                .forward_send(session.connection_id, connection_id, request.clone())
2229        },
2230    );
2231    response.send(proto::Ack {})?;
2232    Ok(())
2233}
2234
2235/// Updates other participants with changes to the diagnostics
2236async fn update_diagnostic_summary(
2237    message: proto::UpdateDiagnosticSummary,
2238    session: MessageContext,
2239) -> Result<()> {
2240    let guest_connection_ids = session
2241        .db()
2242        .await
2243        .update_diagnostic_summary(&message, session.connection_id)
2244        .await?;
2245
2246    broadcast(
2247        Some(session.connection_id),
2248        guest_connection_ids.iter().copied(),
2249        |connection_id| {
2250            session
2251                .peer
2252                .forward_send(session.connection_id, connection_id, message.clone())
2253        },
2254    );
2255
2256    Ok(())
2257}
2258
2259/// Updates other participants with changes to the worktree settings
2260async fn update_worktree_settings(
2261    message: proto::UpdateWorktreeSettings,
2262    session: MessageContext,
2263) -> Result<()> {
2264    let guest_connection_ids = session
2265        .db()
2266        .await
2267        .update_worktree_settings(&message, session.connection_id)
2268        .await?;
2269
2270    broadcast(
2271        Some(session.connection_id),
2272        guest_connection_ids.iter().copied(),
2273        |connection_id| {
2274            session
2275                .peer
2276                .forward_send(session.connection_id, connection_id, message.clone())
2277        },
2278    );
2279
2280    Ok(())
2281}
2282
2283/// Notify other participants that a language server has started.
2284async fn start_language_server(
2285    request: proto::StartLanguageServer,
2286    session: MessageContext,
2287) -> Result<()> {
2288    let guest_connection_ids = session
2289        .db()
2290        .await
2291        .start_language_server(&request, session.connection_id)
2292        .await?;
2293
2294    broadcast(
2295        Some(session.connection_id),
2296        guest_connection_ids.iter().copied(),
2297        |connection_id| {
2298            session
2299                .peer
2300                .forward_send(session.connection_id, connection_id, request.clone())
2301        },
2302    );
2303    Ok(())
2304}
2305
2306/// Notify other participants that a language server has changed.
2307async fn update_language_server(
2308    request: proto::UpdateLanguageServer,
2309    session: MessageContext,
2310) -> Result<()> {
2311    let project_id = ProjectId::from_proto(request.project_id);
2312    let db = session.db().await;
2313
2314    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2315    {
2316        if let Some(capabilities) = update.capabilities.clone() {
2317            db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2318                .await?;
2319        }
2320    }
2321
2322    let project_connection_ids = db
2323        .project_connection_ids(project_id, session.connection_id, true)
2324        .await?;
2325    broadcast(
2326        Some(session.connection_id),
2327        project_connection_ids.iter().copied(),
2328        |connection_id| {
2329            session
2330                .peer
2331                .forward_send(session.connection_id, connection_id, request.clone())
2332        },
2333    );
2334    Ok(())
2335}
2336
2337/// forward a project request to the host. These requests should be read only
2338/// as guests are allowed to send them.
2339async fn forward_read_only_project_request<T>(
2340    request: T,
2341    response: Response<T>,
2342    session: MessageContext,
2343) -> Result<()>
2344where
2345    T: EntityMessage + RequestMessage,
2346{
2347    let project_id = ProjectId::from_proto(request.remote_entity_id());
2348    let host_connection_id = session
2349        .db()
2350        .await
2351        .host_for_read_only_project_request(project_id, session.connection_id)
2352        .await?;
2353    let payload = session.forward_request(host_connection_id, request).await?;
2354    response.send(payload)?;
2355    Ok(())
2356}
2357
2358/// forward a project request to the host. These requests are disallowed
2359/// for guests.
2360async fn forward_mutating_project_request<T>(
2361    request: T,
2362    response: Response<T>,
2363    session: MessageContext,
2364) -> Result<()>
2365where
2366    T: EntityMessage + RequestMessage,
2367{
2368    let project_id = ProjectId::from_proto(request.remote_entity_id());
2369
2370    let host_connection_id = session
2371        .db()
2372        .await
2373        .host_for_mutating_project_request(project_id, session.connection_id)
2374        .await?;
2375    let payload = session.forward_request(host_connection_id, request).await?;
2376    response.send(payload)?;
2377    Ok(())
2378}
2379
2380async fn multi_lsp_query(
2381    request: MultiLspQuery,
2382    response: Response<MultiLspQuery>,
2383    session: MessageContext,
2384) -> Result<()> {
2385    tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2386    tracing::info!("multi_lsp_query message received");
2387    forward_mutating_project_request(request, response, session).await
2388}
2389
2390/// Notify other participants that a new buffer has been created
2391async fn create_buffer_for_peer(
2392    request: proto::CreateBufferForPeer,
2393    session: MessageContext,
2394) -> Result<()> {
2395    session
2396        .db()
2397        .await
2398        .check_user_is_project_host(
2399            ProjectId::from_proto(request.project_id),
2400            session.connection_id,
2401        )
2402        .await?;
2403    let peer_id = request.peer_id.context("invalid peer id")?;
2404    session
2405        .peer
2406        .forward_send(session.connection_id, peer_id.into(), request)?;
2407    Ok(())
2408}
2409
2410/// Notify other participants that a buffer has been updated. This is
2411/// allowed for guests as long as the update is limited to selections.
2412async fn update_buffer(
2413    request: proto::UpdateBuffer,
2414    response: Response<proto::UpdateBuffer>,
2415    session: MessageContext,
2416) -> Result<()> {
2417    let project_id = ProjectId::from_proto(request.project_id);
2418    let mut capability = Capability::ReadOnly;
2419
2420    for op in request.operations.iter() {
2421        match op.variant {
2422            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2423            Some(_) => capability = Capability::ReadWrite,
2424        }
2425    }
2426
2427    let host = {
2428        let guard = session
2429            .db()
2430            .await
2431            .connections_for_buffer_update(project_id, session.connection_id, capability)
2432            .await?;
2433
2434        let (host, guests) = &*guard;
2435
2436        broadcast(
2437            Some(session.connection_id),
2438            guests.clone(),
2439            |connection_id| {
2440                session
2441                    .peer
2442                    .forward_send(session.connection_id, connection_id, request.clone())
2443            },
2444        );
2445
2446        *host
2447    };
2448
2449    if host != session.connection_id {
2450        session.forward_request(host, request.clone()).await?;
2451    }
2452
2453    response.send(proto::Ack {})?;
2454    Ok(())
2455}
2456
2457async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2458    let project_id = ProjectId::from_proto(message.project_id);
2459
2460    let operation = message.operation.as_ref().context("invalid operation")?;
2461    let capability = match operation.variant.as_ref() {
2462        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2463            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2464                match buffer_op.variant {
2465                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2466                        Capability::ReadOnly
2467                    }
2468                    _ => Capability::ReadWrite,
2469                }
2470            } else {
2471                Capability::ReadWrite
2472            }
2473        }
2474        Some(_) => Capability::ReadWrite,
2475        None => Capability::ReadOnly,
2476    };
2477
2478    let guard = session
2479        .db()
2480        .await
2481        .connections_for_buffer_update(project_id, session.connection_id, capability)
2482        .await?;
2483
2484    let (host, guests) = &*guard;
2485
2486    broadcast(
2487        Some(session.connection_id),
2488        guests.iter().chain([host]).copied(),
2489        |connection_id| {
2490            session
2491                .peer
2492                .forward_send(session.connection_id, connection_id, message.clone())
2493        },
2494    );
2495
2496    Ok(())
2497}
2498
2499/// Notify other participants that a project has been updated.
2500async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2501    request: T,
2502    session: MessageContext,
2503) -> Result<()> {
2504    let project_id = ProjectId::from_proto(request.remote_entity_id());
2505    let project_connection_ids = session
2506        .db()
2507        .await
2508        .project_connection_ids(project_id, session.connection_id, false)
2509        .await?;
2510
2511    broadcast(
2512        Some(session.connection_id),
2513        project_connection_ids.iter().copied(),
2514        |connection_id| {
2515            session
2516                .peer
2517                .forward_send(session.connection_id, connection_id, request.clone())
2518        },
2519    );
2520    Ok(())
2521}
2522
2523/// Start following another user in a call.
2524async fn follow(
2525    request: proto::Follow,
2526    response: Response<proto::Follow>,
2527    session: MessageContext,
2528) -> Result<()> {
2529    let room_id = RoomId::from_proto(request.room_id);
2530    let project_id = request.project_id.map(ProjectId::from_proto);
2531    let leader_id = request.leader_id.context("invalid leader id")?.into();
2532    let follower_id = session.connection_id;
2533
2534    session
2535        .db()
2536        .await
2537        .check_room_participants(room_id, leader_id, session.connection_id)
2538        .await?;
2539
2540    let response_payload = session.forward_request(leader_id, request).await?;
2541    response.send(response_payload)?;
2542
2543    if let Some(project_id) = project_id {
2544        let room = session
2545            .db()
2546            .await
2547            .follow(room_id, project_id, leader_id, follower_id)
2548            .await?;
2549        room_updated(&room, &session.peer);
2550    }
2551
2552    Ok(())
2553}
2554
2555/// Stop following another user in a call.
2556async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2557    let room_id = RoomId::from_proto(request.room_id);
2558    let project_id = request.project_id.map(ProjectId::from_proto);
2559    let leader_id = request.leader_id.context("invalid leader id")?.into();
2560    let follower_id = session.connection_id;
2561
2562    session
2563        .db()
2564        .await
2565        .check_room_participants(room_id, leader_id, session.connection_id)
2566        .await?;
2567
2568    session
2569        .peer
2570        .forward_send(session.connection_id, leader_id, request)?;
2571
2572    if let Some(project_id) = project_id {
2573        let room = session
2574            .db()
2575            .await
2576            .unfollow(room_id, project_id, leader_id, follower_id)
2577            .await?;
2578        room_updated(&room, &session.peer);
2579    }
2580
2581    Ok(())
2582}
2583
2584/// Notify everyone following you of your current location.
2585async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2586    let room_id = RoomId::from_proto(request.room_id);
2587    let database = session.db.lock().await;
2588
2589    let connection_ids = if let Some(project_id) = request.project_id {
2590        let project_id = ProjectId::from_proto(project_id);
2591        database
2592            .project_connection_ids(project_id, session.connection_id, true)
2593            .await?
2594    } else {
2595        database
2596            .room_connection_ids(room_id, session.connection_id)
2597            .await?
2598    };
2599
2600    // For now, don't send view update messages back to that view's current leader.
2601    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2602        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2603        _ => None,
2604    });
2605
2606    for connection_id in connection_ids.iter().cloned() {
2607        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2608            session
2609                .peer
2610                .forward_send(session.connection_id, connection_id, request.clone())?;
2611        }
2612    }
2613    Ok(())
2614}
2615
2616/// Get public data about users.
2617async fn get_users(
2618    request: proto::GetUsers,
2619    response: Response<proto::GetUsers>,
2620    session: MessageContext,
2621) -> Result<()> {
2622    let user_ids = request
2623        .user_ids
2624        .into_iter()
2625        .map(UserId::from_proto)
2626        .collect();
2627    let users = session
2628        .db()
2629        .await
2630        .get_users_by_ids(user_ids)
2631        .await?
2632        .into_iter()
2633        .map(|user| proto::User {
2634            id: user.id.to_proto(),
2635            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2636            github_login: user.github_login,
2637            name: user.name,
2638        })
2639        .collect();
2640    response.send(proto::UsersResponse { users })?;
2641    Ok(())
2642}
2643
2644/// Search for users (to invite) buy Github login
2645async fn fuzzy_search_users(
2646    request: proto::FuzzySearchUsers,
2647    response: Response<proto::FuzzySearchUsers>,
2648    session: MessageContext,
2649) -> Result<()> {
2650    let query = request.query;
2651    let users = match query.len() {
2652        0 => vec![],
2653        1 | 2 => session
2654            .db()
2655            .await
2656            .get_user_by_github_login(&query)
2657            .await?
2658            .into_iter()
2659            .collect(),
2660        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2661    };
2662    let users = users
2663        .into_iter()
2664        .filter(|user| user.id != session.user_id())
2665        .map(|user| proto::User {
2666            id: user.id.to_proto(),
2667            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2668            github_login: user.github_login,
2669            name: user.name,
2670        })
2671        .collect();
2672    response.send(proto::UsersResponse { users })?;
2673    Ok(())
2674}
2675
2676/// Send a contact request to another user.
2677async fn request_contact(
2678    request: proto::RequestContact,
2679    response: Response<proto::RequestContact>,
2680    session: MessageContext,
2681) -> Result<()> {
2682    let requester_id = session.user_id();
2683    let responder_id = UserId::from_proto(request.responder_id);
2684    if requester_id == responder_id {
2685        return Err(anyhow!("cannot add yourself as a contact"))?;
2686    }
2687
2688    let notifications = session
2689        .db()
2690        .await
2691        .send_contact_request(requester_id, responder_id)
2692        .await?;
2693
2694    // Update outgoing contact requests of requester
2695    let mut update = proto::UpdateContacts::default();
2696    update.outgoing_requests.push(responder_id.to_proto());
2697    for connection_id in session
2698        .connection_pool()
2699        .await
2700        .user_connection_ids(requester_id)
2701    {
2702        session.peer.send(connection_id, update.clone())?;
2703    }
2704
2705    // Update incoming contact requests of responder
2706    let mut update = proto::UpdateContacts::default();
2707    update
2708        .incoming_requests
2709        .push(proto::IncomingContactRequest {
2710            requester_id: requester_id.to_proto(),
2711        });
2712    let connection_pool = session.connection_pool().await;
2713    for connection_id in connection_pool.user_connection_ids(responder_id) {
2714        session.peer.send(connection_id, update.clone())?;
2715    }
2716
2717    send_notifications(&connection_pool, &session.peer, notifications);
2718
2719    response.send(proto::Ack {})?;
2720    Ok(())
2721}
2722
2723/// Accept or decline a contact request
2724async fn respond_to_contact_request(
2725    request: proto::RespondToContactRequest,
2726    response: Response<proto::RespondToContactRequest>,
2727    session: MessageContext,
2728) -> Result<()> {
2729    let responder_id = session.user_id();
2730    let requester_id = UserId::from_proto(request.requester_id);
2731    let db = session.db().await;
2732    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2733        db.dismiss_contact_notification(responder_id, requester_id)
2734            .await?;
2735    } else {
2736        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2737
2738        let notifications = db
2739            .respond_to_contact_request(responder_id, requester_id, accept)
2740            .await?;
2741        let requester_busy = db.is_user_busy(requester_id).await?;
2742        let responder_busy = db.is_user_busy(responder_id).await?;
2743
2744        let pool = session.connection_pool().await;
2745        // Update responder with new contact
2746        let mut update = proto::UpdateContacts::default();
2747        if accept {
2748            update
2749                .contacts
2750                .push(contact_for_user(requester_id, requester_busy, &pool));
2751        }
2752        update
2753            .remove_incoming_requests
2754            .push(requester_id.to_proto());
2755        for connection_id in pool.user_connection_ids(responder_id) {
2756            session.peer.send(connection_id, update.clone())?;
2757        }
2758
2759        // Update requester with new contact
2760        let mut update = proto::UpdateContacts::default();
2761        if accept {
2762            update
2763                .contacts
2764                .push(contact_for_user(responder_id, responder_busy, &pool));
2765        }
2766        update
2767            .remove_outgoing_requests
2768            .push(responder_id.to_proto());
2769
2770        for connection_id in pool.user_connection_ids(requester_id) {
2771            session.peer.send(connection_id, update.clone())?;
2772        }
2773
2774        send_notifications(&pool, &session.peer, notifications);
2775    }
2776
2777    response.send(proto::Ack {})?;
2778    Ok(())
2779}
2780
2781/// Remove a contact.
2782async fn remove_contact(
2783    request: proto::RemoveContact,
2784    response: Response<proto::RemoveContact>,
2785    session: MessageContext,
2786) -> Result<()> {
2787    let requester_id = session.user_id();
2788    let responder_id = UserId::from_proto(request.user_id);
2789    let db = session.db().await;
2790    let (contact_accepted, deleted_notification_id) =
2791        db.remove_contact(requester_id, responder_id).await?;
2792
2793    let pool = session.connection_pool().await;
2794    // Update outgoing contact requests of requester
2795    let mut update = proto::UpdateContacts::default();
2796    if contact_accepted {
2797        update.remove_contacts.push(responder_id.to_proto());
2798    } else {
2799        update
2800            .remove_outgoing_requests
2801            .push(responder_id.to_proto());
2802    }
2803    for connection_id in pool.user_connection_ids(requester_id) {
2804        session.peer.send(connection_id, update.clone())?;
2805    }
2806
2807    // Update incoming contact requests of responder
2808    let mut update = proto::UpdateContacts::default();
2809    if contact_accepted {
2810        update.remove_contacts.push(requester_id.to_proto());
2811    } else {
2812        update
2813            .remove_incoming_requests
2814            .push(requester_id.to_proto());
2815    }
2816    for connection_id in pool.user_connection_ids(responder_id) {
2817        session.peer.send(connection_id, update.clone())?;
2818        if let Some(notification_id) = deleted_notification_id {
2819            session.peer.send(
2820                connection_id,
2821                proto::DeleteNotification {
2822                    notification_id: notification_id.to_proto(),
2823                },
2824            )?;
2825        }
2826    }
2827
2828    response.send(proto::Ack {})?;
2829    Ok(())
2830}
2831
2832fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2833    version.0.minor() < 139
2834}
2835
2836async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2837    if is_staff {
2838        return Ok(proto::Plan::ZedPro);
2839    }
2840
2841    let subscription = db.get_active_billing_subscription(user_id).await?;
2842    let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2843
2844    let plan = if let Some(subscription_kind) = subscription_kind {
2845        match subscription_kind {
2846            SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2847            SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2848            SubscriptionKind::ZedFree => proto::Plan::Free,
2849        }
2850    } else {
2851        proto::Plan::Free
2852    };
2853
2854    Ok(plan)
2855}
2856
2857async fn make_update_user_plan_message(
2858    user: &User,
2859    is_staff: bool,
2860    db: &Arc<Database>,
2861    llm_db: Option<Arc<LlmDatabase>>,
2862) -> Result<proto::UpdateUserPlan> {
2863    let feature_flags = db.get_user_flags(user.id).await?;
2864    let plan = current_plan(db, user.id, is_staff).await?;
2865    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
2866    let billing_preferences = db.get_billing_preferences(user.id).await?;
2867
2868    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
2869        let subscription = db.get_active_billing_subscription(user.id).await?;
2870
2871        let subscription_period =
2872            crate::db::billing_subscription::Model::current_period(subscription, is_staff);
2873
2874        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
2875            llm_db
2876                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
2877                .await?
2878        } else {
2879            None
2880        };
2881
2882        (subscription_period, usage)
2883    } else {
2884        (None, None)
2885    };
2886
2887    let bypass_account_age_check = feature_flags
2888        .iter()
2889        .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
2890    let account_too_young = !matches!(plan, proto::Plan::ZedPro)
2891        && !bypass_account_age_check
2892        && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;
2893
2894    Ok(proto::UpdateUserPlan {
2895        plan: plan.into(),
2896        trial_started_at: billing_customer
2897            .as_ref()
2898            .and_then(|billing_customer| billing_customer.trial_started_at)
2899            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2900        is_usage_based_billing_enabled: if is_staff {
2901            Some(true)
2902        } else {
2903            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
2904        },
2905        subscription_period: subscription_period.map(|(started_at, ended_at)| {
2906            proto::SubscriptionPeriod {
2907                started_at: started_at.timestamp() as u64,
2908                ended_at: ended_at.timestamp() as u64,
2909            }
2910        }),
2911        account_too_young: Some(account_too_young),
2912        has_overdue_invoices: billing_customer
2913            .map(|billing_customer| billing_customer.has_overdue_invoices),
2914        usage: Some(
2915            usage
2916                .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
2917                .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
2918        ),
2919    })
2920}
2921
2922fn model_requests_limit(
2923    plan: cloud_llm_client::Plan,
2924    feature_flags: &Vec<String>,
2925) -> cloud_llm_client::UsageLimit {
2926    match plan.model_requests_limit() {
2927        cloud_llm_client::UsageLimit::Limited(limit) => {
2928            let limit = if plan == cloud_llm_client::Plan::ZedProTrial
2929                && feature_flags
2930                    .iter()
2931                    .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2932            {
2933                1_000
2934            } else {
2935                limit
2936            };
2937
2938            cloud_llm_client::UsageLimit::Limited(limit)
2939        }
2940        cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited,
2941    }
2942}
2943
2944fn subscription_usage_to_proto(
2945    plan: proto::Plan,
2946    usage: crate::llm::db::subscription_usage::Model,
2947    feature_flags: &Vec<String>,
2948) -> proto::SubscriptionUsage {
2949    let plan = match plan {
2950        proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2951        proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
2952        proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
2953    };
2954
2955    proto::SubscriptionUsage {
2956        model_requests_usage_amount: usage.model_requests as u32,
2957        model_requests_usage_limit: Some(proto::UsageLimit {
2958            variant: Some(match model_requests_limit(plan, feature_flags) {
2959                cloud_llm_client::UsageLimit::Limited(limit) => {
2960                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2961                        limit: limit as u32,
2962                    })
2963                }
2964                cloud_llm_client::UsageLimit::Unlimited => {
2965                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2966                }
2967            }),
2968        }),
2969        edit_predictions_usage_amount: usage.edit_predictions as u32,
2970        edit_predictions_usage_limit: Some(proto::UsageLimit {
2971            variant: Some(match plan.edit_predictions_limit() {
2972                cloud_llm_client::UsageLimit::Limited(limit) => {
2973                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2974                        limit: limit as u32,
2975                    })
2976                }
2977                cloud_llm_client::UsageLimit::Unlimited => {
2978                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2979                }
2980            }),
2981        }),
2982    }
2983}
2984
2985fn make_default_subscription_usage(
2986    plan: proto::Plan,
2987    feature_flags: &Vec<String>,
2988) -> proto::SubscriptionUsage {
2989    let plan = match plan {
2990        proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2991        proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
2992        proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
2993    };
2994
2995    proto::SubscriptionUsage {
2996        model_requests_usage_amount: 0,
2997        model_requests_usage_limit: Some(proto::UsageLimit {
2998            variant: Some(match model_requests_limit(plan, feature_flags) {
2999                cloud_llm_client::UsageLimit::Limited(limit) => {
3000                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3001                        limit: limit as u32,
3002                    })
3003                }
3004                cloud_llm_client::UsageLimit::Unlimited => {
3005                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3006                }
3007            }),
3008        }),
3009        edit_predictions_usage_amount: 0,
3010        edit_predictions_usage_limit: Some(proto::UsageLimit {
3011            variant: Some(match plan.edit_predictions_limit() {
3012                cloud_llm_client::UsageLimit::Limited(limit) => {
3013                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3014                        limit: limit as u32,
3015                    })
3016                }
3017                cloud_llm_client::UsageLimit::Unlimited => {
3018                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3019                }
3020            }),
3021        }),
3022    }
3023}
3024
3025async fn update_user_plan(session: &Session) -> Result<()> {
3026    let db = session.db().await;
3027
3028    let update_user_plan = make_update_user_plan_message(
3029        session.principal.user(),
3030        session.is_staff(),
3031        &db.0,
3032        session.app_state.llm_db.clone(),
3033    )
3034    .await?;
3035
3036    session
3037        .peer
3038        .send(session.connection_id, update_user_plan)
3039        .trace_err();
3040
3041    Ok(())
3042}
3043
3044async fn subscribe_to_channels(
3045    _: proto::SubscribeToChannels,
3046    session: MessageContext,
3047) -> Result<()> {
3048    subscribe_user_to_channels(session.user_id(), &session).await?;
3049    Ok(())
3050}
3051
3052async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3053    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3054    let mut pool = session.connection_pool().await;
3055    for membership in &channels_for_user.channel_memberships {
3056        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3057    }
3058    session.peer.send(
3059        session.connection_id,
3060        build_update_user_channels(&channels_for_user),
3061    )?;
3062    session.peer.send(
3063        session.connection_id,
3064        build_channels_update(channels_for_user),
3065    )?;
3066    Ok(())
3067}
3068
3069/// Creates a new channel.
3070async fn create_channel(
3071    request: proto::CreateChannel,
3072    response: Response<proto::CreateChannel>,
3073    session: MessageContext,
3074) -> Result<()> {
3075    let db = session.db().await;
3076
3077    let parent_id = request.parent_id.map(ChannelId::from_proto);
3078    let (channel, membership) = db
3079        .create_channel(&request.name, parent_id, session.user_id())
3080        .await?;
3081
3082    let root_id = channel.root_id();
3083    let channel = Channel::from_model(channel);
3084
3085    response.send(proto::CreateChannelResponse {
3086        channel: Some(channel.to_proto()),
3087        parent_id: request.parent_id,
3088    })?;
3089
3090    let mut connection_pool = session.connection_pool().await;
3091    if let Some(membership) = membership {
3092        connection_pool.subscribe_to_channel(
3093            membership.user_id,
3094            membership.channel_id,
3095            membership.role,
3096        );
3097        let update = proto::UpdateUserChannels {
3098            channel_memberships: vec![proto::ChannelMembership {
3099                channel_id: membership.channel_id.to_proto(),
3100                role: membership.role.into(),
3101            }],
3102            ..Default::default()
3103        };
3104        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3105            session.peer.send(connection_id, update.clone())?;
3106        }
3107    }
3108
3109    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3110        if !role.can_see_channel(channel.visibility) {
3111            continue;
3112        }
3113
3114        let update = proto::UpdateChannels {
3115            channels: vec![channel.to_proto()],
3116            ..Default::default()
3117        };
3118        session.peer.send(connection_id, update.clone())?;
3119    }
3120
3121    Ok(())
3122}
3123
3124/// Delete a channel
3125async fn delete_channel(
3126    request: proto::DeleteChannel,
3127    response: Response<proto::DeleteChannel>,
3128    session: MessageContext,
3129) -> Result<()> {
3130    let db = session.db().await;
3131
3132    let channel_id = request.channel_id;
3133    let (root_channel, removed_channels) = db
3134        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3135        .await?;
3136    response.send(proto::Ack {})?;
3137
3138    // Notify members of removed channels
3139    let mut update = proto::UpdateChannels::default();
3140    update
3141        .delete_channels
3142        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3143
3144    let connection_pool = session.connection_pool().await;
3145    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3146        session.peer.send(connection_id, update.clone())?;
3147    }
3148
3149    Ok(())
3150}
3151
3152/// Invite someone to join a channel.
3153async fn invite_channel_member(
3154    request: proto::InviteChannelMember,
3155    response: Response<proto::InviteChannelMember>,
3156    session: MessageContext,
3157) -> Result<()> {
3158    let db = session.db().await;
3159    let channel_id = ChannelId::from_proto(request.channel_id);
3160    let invitee_id = UserId::from_proto(request.user_id);
3161    let InviteMemberResult {
3162        channel,
3163        notifications,
3164    } = db
3165        .invite_channel_member(
3166            channel_id,
3167            invitee_id,
3168            session.user_id(),
3169            request.role().into(),
3170        )
3171        .await?;
3172
3173    let update = proto::UpdateChannels {
3174        channel_invitations: vec![channel.to_proto()],
3175        ..Default::default()
3176    };
3177
3178    let connection_pool = session.connection_pool().await;
3179    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3180        session.peer.send(connection_id, update.clone())?;
3181    }
3182
3183    send_notifications(&connection_pool, &session.peer, notifications);
3184
3185    response.send(proto::Ack {})?;
3186    Ok(())
3187}
3188
3189/// remove someone from a channel
3190async fn remove_channel_member(
3191    request: proto::RemoveChannelMember,
3192    response: Response<proto::RemoveChannelMember>,
3193    session: MessageContext,
3194) -> Result<()> {
3195    let db = session.db().await;
3196    let channel_id = ChannelId::from_proto(request.channel_id);
3197    let member_id = UserId::from_proto(request.user_id);
3198
3199    let RemoveChannelMemberResult {
3200        membership_update,
3201        notification_id,
3202    } = db
3203        .remove_channel_member(channel_id, member_id, session.user_id())
3204        .await?;
3205
3206    let mut connection_pool = session.connection_pool().await;
3207    notify_membership_updated(
3208        &mut connection_pool,
3209        membership_update,
3210        member_id,
3211        &session.peer,
3212    );
3213    for connection_id in connection_pool.user_connection_ids(member_id) {
3214        if let Some(notification_id) = notification_id {
3215            session
3216                .peer
3217                .send(
3218                    connection_id,
3219                    proto::DeleteNotification {
3220                        notification_id: notification_id.to_proto(),
3221                    },
3222                )
3223                .trace_err();
3224        }
3225    }
3226
3227    response.send(proto::Ack {})?;
3228    Ok(())
3229}
3230
3231/// Toggle the channel between public and private.
3232/// Care is taken to maintain the invariant that public channels only descend from public channels,
3233/// (though members-only channels can appear at any point in the hierarchy).
3234async fn set_channel_visibility(
3235    request: proto::SetChannelVisibility,
3236    response: Response<proto::SetChannelVisibility>,
3237    session: MessageContext,
3238) -> Result<()> {
3239    let db = session.db().await;
3240    let channel_id = ChannelId::from_proto(request.channel_id);
3241    let visibility = request.visibility().into();
3242
3243    let channel_model = db
3244        .set_channel_visibility(channel_id, visibility, session.user_id())
3245        .await?;
3246    let root_id = channel_model.root_id();
3247    let channel = Channel::from_model(channel_model);
3248
3249    let mut connection_pool = session.connection_pool().await;
3250    for (user_id, role) in connection_pool
3251        .channel_user_ids(root_id)
3252        .collect::<Vec<_>>()
3253        .into_iter()
3254    {
3255        let update = if role.can_see_channel(channel.visibility) {
3256            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3257            proto::UpdateChannels {
3258                channels: vec![channel.to_proto()],
3259                ..Default::default()
3260            }
3261        } else {
3262            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3263            proto::UpdateChannels {
3264                delete_channels: vec![channel.id.to_proto()],
3265                ..Default::default()
3266            }
3267        };
3268
3269        for connection_id in connection_pool.user_connection_ids(user_id) {
3270            session.peer.send(connection_id, update.clone())?;
3271        }
3272    }
3273
3274    response.send(proto::Ack {})?;
3275    Ok(())
3276}
3277
3278/// Alter the role for a user in the channel.
3279async fn set_channel_member_role(
3280    request: proto::SetChannelMemberRole,
3281    response: Response<proto::SetChannelMemberRole>,
3282    session: MessageContext,
3283) -> Result<()> {
3284    let db = session.db().await;
3285    let channel_id = ChannelId::from_proto(request.channel_id);
3286    let member_id = UserId::from_proto(request.user_id);
3287    let result = db
3288        .set_channel_member_role(
3289            channel_id,
3290            session.user_id(),
3291            member_id,
3292            request.role().into(),
3293        )
3294        .await?;
3295
3296    match result {
3297        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3298            let mut connection_pool = session.connection_pool().await;
3299            notify_membership_updated(
3300                &mut connection_pool,
3301                membership_update,
3302                member_id,
3303                &session.peer,
3304            )
3305        }
3306        db::SetMemberRoleResult::InviteUpdated(channel) => {
3307            let update = proto::UpdateChannels {
3308                channel_invitations: vec![channel.to_proto()],
3309                ..Default::default()
3310            };
3311
3312            for connection_id in session
3313                .connection_pool()
3314                .await
3315                .user_connection_ids(member_id)
3316            {
3317                session.peer.send(connection_id, update.clone())?;
3318            }
3319        }
3320    }
3321
3322    response.send(proto::Ack {})?;
3323    Ok(())
3324}
3325
3326/// Change the name of a channel
3327async fn rename_channel(
3328    request: proto::RenameChannel,
3329    response: Response<proto::RenameChannel>,
3330    session: MessageContext,
3331) -> Result<()> {
3332    let db = session.db().await;
3333    let channel_id = ChannelId::from_proto(request.channel_id);
3334    let channel_model = db
3335        .rename_channel(channel_id, session.user_id(), &request.name)
3336        .await?;
3337    let root_id = channel_model.root_id();
3338    let channel = Channel::from_model(channel_model);
3339
3340    response.send(proto::RenameChannelResponse {
3341        channel: Some(channel.to_proto()),
3342    })?;
3343
3344    let connection_pool = session.connection_pool().await;
3345    let update = proto::UpdateChannels {
3346        channels: vec![channel.to_proto()],
3347        ..Default::default()
3348    };
3349    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3350        if role.can_see_channel(channel.visibility) {
3351            session.peer.send(connection_id, update.clone())?;
3352        }
3353    }
3354
3355    Ok(())
3356}
3357
3358/// Move a channel to a new parent.
3359async fn move_channel(
3360    request: proto::MoveChannel,
3361    response: Response<proto::MoveChannel>,
3362    session: MessageContext,
3363) -> Result<()> {
3364    let channel_id = ChannelId::from_proto(request.channel_id);
3365    let to = ChannelId::from_proto(request.to);
3366
3367    let (root_id, channels) = session
3368        .db()
3369        .await
3370        .move_channel(channel_id, to, session.user_id())
3371        .await?;
3372
3373    let connection_pool = session.connection_pool().await;
3374    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3375        let channels = channels
3376            .iter()
3377            .filter_map(|channel| {
3378                if role.can_see_channel(channel.visibility) {
3379                    Some(channel.to_proto())
3380                } else {
3381                    None
3382                }
3383            })
3384            .collect::<Vec<_>>();
3385        if channels.is_empty() {
3386            continue;
3387        }
3388
3389        let update = proto::UpdateChannels {
3390            channels,
3391            ..Default::default()
3392        };
3393
3394        session.peer.send(connection_id, update.clone())?;
3395    }
3396
3397    response.send(Ack {})?;
3398    Ok(())
3399}
3400
3401async fn reorder_channel(
3402    request: proto::ReorderChannel,
3403    response: Response<proto::ReorderChannel>,
3404    session: MessageContext,
3405) -> Result<()> {
3406    let channel_id = ChannelId::from_proto(request.channel_id);
3407    let direction = request.direction();
3408
3409    let updated_channels = session
3410        .db()
3411        .await
3412        .reorder_channel(channel_id, direction, session.user_id())
3413        .await?;
3414
3415    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3416        let connection_pool = session.connection_pool().await;
3417        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3418            let channels = updated_channels
3419                .iter()
3420                .filter_map(|channel| {
3421                    if role.can_see_channel(channel.visibility) {
3422                        Some(channel.to_proto())
3423                    } else {
3424                        None
3425                    }
3426                })
3427                .collect::<Vec<_>>();
3428
3429            if channels.is_empty() {
3430                continue;
3431            }
3432
3433            let update = proto::UpdateChannels {
3434                channels,
3435                ..Default::default()
3436            };
3437
3438            session.peer.send(connection_id, update.clone())?;
3439        }
3440    }
3441
3442    response.send(Ack {})?;
3443    Ok(())
3444}
3445
3446/// Get the list of channel members
3447async fn get_channel_members(
3448    request: proto::GetChannelMembers,
3449    response: Response<proto::GetChannelMembers>,
3450    session: MessageContext,
3451) -> Result<()> {
3452    let db = session.db().await;
3453    let channel_id = ChannelId::from_proto(request.channel_id);
3454    let limit = if request.limit == 0 {
3455        u16::MAX as u64
3456    } else {
3457        request.limit
3458    };
3459    let (members, users) = db
3460        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3461        .await?;
3462    response.send(proto::GetChannelMembersResponse { members, users })?;
3463    Ok(())
3464}
3465
3466/// Accept or decline a channel invitation.
3467async fn respond_to_channel_invite(
3468    request: proto::RespondToChannelInvite,
3469    response: Response<proto::RespondToChannelInvite>,
3470    session: MessageContext,
3471) -> Result<()> {
3472    let db = session.db().await;
3473    let channel_id = ChannelId::from_proto(request.channel_id);
3474    let RespondToChannelInvite {
3475        membership_update,
3476        notifications,
3477    } = db
3478        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3479        .await?;
3480
3481    let mut connection_pool = session.connection_pool().await;
3482    if let Some(membership_update) = membership_update {
3483        notify_membership_updated(
3484            &mut connection_pool,
3485            membership_update,
3486            session.user_id(),
3487            &session.peer,
3488        );
3489    } else {
3490        let update = proto::UpdateChannels {
3491            remove_channel_invitations: vec![channel_id.to_proto()],
3492            ..Default::default()
3493        };
3494
3495        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3496            session.peer.send(connection_id, update.clone())?;
3497        }
3498    };
3499
3500    send_notifications(&connection_pool, &session.peer, notifications);
3501
3502    response.send(proto::Ack {})?;
3503
3504    Ok(())
3505}
3506
3507/// Join the channels' room
3508async fn join_channel(
3509    request: proto::JoinChannel,
3510    response: Response<proto::JoinChannel>,
3511    session: MessageContext,
3512) -> Result<()> {
3513    let channel_id = ChannelId::from_proto(request.channel_id);
3514    join_channel_internal(channel_id, Box::new(response), session).await
3515}
3516
3517trait JoinChannelInternalResponse {
3518    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3519}
3520impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3521    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3522        Response::<proto::JoinChannel>::send(self, result)
3523    }
3524}
3525impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3526    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3527        Response::<proto::JoinRoom>::send(self, result)
3528    }
3529}
3530
3531async fn join_channel_internal(
3532    channel_id: ChannelId,
3533    response: Box<impl JoinChannelInternalResponse>,
3534    session: MessageContext,
3535) -> Result<()> {
3536    let joined_room = {
3537        let mut db = session.db().await;
3538        // If zed quits without leaving the room, and the user re-opens zed before the
3539        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3540        // room they were in.
3541        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3542            tracing::info!(
3543                stale_connection_id = %connection,
3544                "cleaning up stale connection",
3545            );
3546            drop(db);
3547            leave_room_for_session(&session, connection).await?;
3548            db = session.db().await;
3549        }
3550
3551        let (joined_room, membership_updated, role) = db
3552            .join_channel(channel_id, session.user_id(), session.connection_id)
3553            .await?;
3554
3555        let live_kit_connection_info =
3556            session
3557                .app_state
3558                .livekit_client
3559                .as_ref()
3560                .and_then(|live_kit| {
3561                    let (can_publish, token) = if role == ChannelRole::Guest {
3562                        (
3563                            false,
3564                            live_kit
3565                                .guest_token(
3566                                    &joined_room.room.livekit_room,
3567                                    &session.user_id().to_string(),
3568                                )
3569                                .trace_err()?,
3570                        )
3571                    } else {
3572                        (
3573                            true,
3574                            live_kit
3575                                .room_token(
3576                                    &joined_room.room.livekit_room,
3577                                    &session.user_id().to_string(),
3578                                )
3579                                .trace_err()?,
3580                        )
3581                    };
3582
3583                    Some(LiveKitConnectionInfo {
3584                        server_url: live_kit.url().into(),
3585                        token,
3586                        can_publish,
3587                    })
3588                });
3589
3590        response.send(proto::JoinRoomResponse {
3591            room: Some(joined_room.room.clone()),
3592            channel_id: joined_room
3593                .channel
3594                .as_ref()
3595                .map(|channel| channel.id.to_proto()),
3596            live_kit_connection_info,
3597        })?;
3598
3599        let mut connection_pool = session.connection_pool().await;
3600        if let Some(membership_updated) = membership_updated {
3601            notify_membership_updated(
3602                &mut connection_pool,
3603                membership_updated,
3604                session.user_id(),
3605                &session.peer,
3606            );
3607        }
3608
3609        room_updated(&joined_room.room, &session.peer);
3610
3611        joined_room
3612    };
3613
3614    channel_updated(
3615        &joined_room.channel.context("channel not returned")?,
3616        &joined_room.room,
3617        &session.peer,
3618        &*session.connection_pool().await,
3619    );
3620
3621    update_user_contacts(session.user_id(), &session).await?;
3622    Ok(())
3623}
3624
3625/// Start editing the channel notes
3626async fn join_channel_buffer(
3627    request: proto::JoinChannelBuffer,
3628    response: Response<proto::JoinChannelBuffer>,
3629    session: MessageContext,
3630) -> Result<()> {
3631    let db = session.db().await;
3632    let channel_id = ChannelId::from_proto(request.channel_id);
3633
3634    let open_response = db
3635        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3636        .await?;
3637
3638    let collaborators = open_response.collaborators.clone();
3639    response.send(open_response)?;
3640
3641    let update = UpdateChannelBufferCollaborators {
3642        channel_id: channel_id.to_proto(),
3643        collaborators: collaborators.clone(),
3644    };
3645    channel_buffer_updated(
3646        session.connection_id,
3647        collaborators
3648            .iter()
3649            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3650        &update,
3651        &session.peer,
3652    );
3653
3654    Ok(())
3655}
3656
3657/// Edit the channel notes
3658async fn update_channel_buffer(
3659    request: proto::UpdateChannelBuffer,
3660    session: MessageContext,
3661) -> Result<()> {
3662    let db = session.db().await;
3663    let channel_id = ChannelId::from_proto(request.channel_id);
3664
3665    let (collaborators, epoch, version) = db
3666        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3667        .await?;
3668
3669    channel_buffer_updated(
3670        session.connection_id,
3671        collaborators.clone(),
3672        &proto::UpdateChannelBuffer {
3673            channel_id: channel_id.to_proto(),
3674            operations: request.operations,
3675        },
3676        &session.peer,
3677    );
3678
3679    let pool = &*session.connection_pool().await;
3680
3681    let non_collaborators =
3682        pool.channel_connection_ids(channel_id)
3683            .filter_map(|(connection_id, _)| {
3684                if collaborators.contains(&connection_id) {
3685                    None
3686                } else {
3687                    Some(connection_id)
3688                }
3689            });
3690
3691    broadcast(None, non_collaborators, |peer_id| {
3692        session.peer.send(
3693            peer_id,
3694            proto::UpdateChannels {
3695                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3696                    channel_id: channel_id.to_proto(),
3697                    epoch: epoch as u64,
3698                    version: version.clone(),
3699                }],
3700                ..Default::default()
3701            },
3702        )
3703    });
3704
3705    Ok(())
3706}
3707
3708/// Rejoin the channel notes after a connection blip
3709async fn rejoin_channel_buffers(
3710    request: proto::RejoinChannelBuffers,
3711    response: Response<proto::RejoinChannelBuffers>,
3712    session: MessageContext,
3713) -> Result<()> {
3714    let db = session.db().await;
3715    let buffers = db
3716        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3717        .await?;
3718
3719    for rejoined_buffer in &buffers {
3720        let collaborators_to_notify = rejoined_buffer
3721            .buffer
3722            .collaborators
3723            .iter()
3724            .filter_map(|c| Some(c.peer_id?.into()));
3725        channel_buffer_updated(
3726            session.connection_id,
3727            collaborators_to_notify,
3728            &proto::UpdateChannelBufferCollaborators {
3729                channel_id: rejoined_buffer.buffer.channel_id,
3730                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3731            },
3732            &session.peer,
3733        );
3734    }
3735
3736    response.send(proto::RejoinChannelBuffersResponse {
3737        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3738    })?;
3739
3740    Ok(())
3741}
3742
3743/// Stop editing the channel notes
3744async fn leave_channel_buffer(
3745    request: proto::LeaveChannelBuffer,
3746    response: Response<proto::LeaveChannelBuffer>,
3747    session: MessageContext,
3748) -> Result<()> {
3749    let db = session.db().await;
3750    let channel_id = ChannelId::from_proto(request.channel_id);
3751
3752    let left_buffer = db
3753        .leave_channel_buffer(channel_id, session.connection_id)
3754        .await?;
3755
3756    response.send(Ack {})?;
3757
3758    channel_buffer_updated(
3759        session.connection_id,
3760        left_buffer.connections,
3761        &proto::UpdateChannelBufferCollaborators {
3762            channel_id: channel_id.to_proto(),
3763            collaborators: left_buffer.collaborators,
3764        },
3765        &session.peer,
3766    );
3767
3768    Ok(())
3769}
3770
3771fn channel_buffer_updated<T: EnvelopedMessage>(
3772    sender_id: ConnectionId,
3773    collaborators: impl IntoIterator<Item = ConnectionId>,
3774    message: &T,
3775    peer: &Peer,
3776) {
3777    broadcast(Some(sender_id), collaborators, |peer_id| {
3778        peer.send(peer_id, message.clone())
3779    });
3780}
3781
3782fn send_notifications(
3783    connection_pool: &ConnectionPool,
3784    peer: &Peer,
3785    notifications: db::NotificationBatch,
3786) {
3787    for (user_id, notification) in notifications {
3788        for connection_id in connection_pool.user_connection_ids(user_id) {
3789            if let Err(error) = peer.send(
3790                connection_id,
3791                proto::AddNotification {
3792                    notification: Some(notification.clone()),
3793                },
3794            ) {
3795                tracing::error!(
3796                    "failed to send notification to {:?} {}",
3797                    connection_id,
3798                    error
3799                );
3800            }
3801        }
3802    }
3803}
3804
3805/// Send a message to the channel
3806async fn send_channel_message(
3807    request: proto::SendChannelMessage,
3808    response: Response<proto::SendChannelMessage>,
3809    session: MessageContext,
3810) -> Result<()> {
3811    // Validate the message body.
3812    let body = request.body.trim().to_string();
3813    if body.len() > MAX_MESSAGE_LEN {
3814        return Err(anyhow!("message is too long"))?;
3815    }
3816    if body.is_empty() {
3817        return Err(anyhow!("message can't be blank"))?;
3818    }
3819
3820    // TODO: adjust mentions if body is trimmed
3821
3822    let timestamp = OffsetDateTime::now_utc();
3823    let nonce = request.nonce.context("nonce can't be blank")?;
3824
3825    let channel_id = ChannelId::from_proto(request.channel_id);
3826    let CreatedChannelMessage {
3827        message_id,
3828        participant_connection_ids,
3829        notifications,
3830    } = session
3831        .db()
3832        .await
3833        .create_channel_message(
3834            channel_id,
3835            session.user_id(),
3836            &body,
3837            &request.mentions,
3838            timestamp,
3839            nonce.clone().into(),
3840            request.reply_to_message_id.map(MessageId::from_proto),
3841        )
3842        .await?;
3843
3844    let message = proto::ChannelMessage {
3845        sender_id: session.user_id().to_proto(),
3846        id: message_id.to_proto(),
3847        body,
3848        mentions: request.mentions,
3849        timestamp: timestamp.unix_timestamp() as u64,
3850        nonce: Some(nonce),
3851        reply_to_message_id: request.reply_to_message_id,
3852        edited_at: None,
3853    };
3854    broadcast(
3855        Some(session.connection_id),
3856        participant_connection_ids.clone(),
3857        |connection| {
3858            session.peer.send(
3859                connection,
3860                proto::ChannelMessageSent {
3861                    channel_id: channel_id.to_proto(),
3862                    message: Some(message.clone()),
3863                },
3864            )
3865        },
3866    );
3867    response.send(proto::SendChannelMessageResponse {
3868        message: Some(message),
3869    })?;
3870
3871    let pool = &*session.connection_pool().await;
3872    let non_participants =
3873        pool.channel_connection_ids(channel_id)
3874            .filter_map(|(connection_id, _)| {
3875                if participant_connection_ids.contains(&connection_id) {
3876                    None
3877                } else {
3878                    Some(connection_id)
3879                }
3880            });
3881    broadcast(None, non_participants, |peer_id| {
3882        session.peer.send(
3883            peer_id,
3884            proto::UpdateChannels {
3885                latest_channel_message_ids: vec![proto::ChannelMessageId {
3886                    channel_id: channel_id.to_proto(),
3887                    message_id: message_id.to_proto(),
3888                }],
3889                ..Default::default()
3890            },
3891        )
3892    });
3893    send_notifications(pool, &session.peer, notifications);
3894
3895    Ok(())
3896}
3897
3898/// Delete a channel message
3899async fn remove_channel_message(
3900    request: proto::RemoveChannelMessage,
3901    response: Response<proto::RemoveChannelMessage>,
3902    session: MessageContext,
3903) -> Result<()> {
3904    let channel_id = ChannelId::from_proto(request.channel_id);
3905    let message_id = MessageId::from_proto(request.message_id);
3906    let (connection_ids, existing_notification_ids) = session
3907        .db()
3908        .await
3909        .remove_channel_message(channel_id, message_id, session.user_id())
3910        .await?;
3911
3912    broadcast(
3913        Some(session.connection_id),
3914        connection_ids,
3915        move |connection| {
3916            session.peer.send(connection, request.clone())?;
3917
3918            for notification_id in &existing_notification_ids {
3919                session.peer.send(
3920                    connection,
3921                    proto::DeleteNotification {
3922                        notification_id: (*notification_id).to_proto(),
3923                    },
3924                )?;
3925            }
3926
3927            Ok(())
3928        },
3929    );
3930    response.send(proto::Ack {})?;
3931    Ok(())
3932}
3933
3934async fn update_channel_message(
3935    request: proto::UpdateChannelMessage,
3936    response: Response<proto::UpdateChannelMessage>,
3937    session: MessageContext,
3938) -> Result<()> {
3939    let channel_id = ChannelId::from_proto(request.channel_id);
3940    let message_id = MessageId::from_proto(request.message_id);
3941    let updated_at = OffsetDateTime::now_utc();
3942    let UpdatedChannelMessage {
3943        message_id,
3944        participant_connection_ids,
3945        notifications,
3946        reply_to_message_id,
3947        timestamp,
3948        deleted_mention_notification_ids,
3949        updated_mention_notifications,
3950    } = session
3951        .db()
3952        .await
3953        .update_channel_message(
3954            channel_id,
3955            message_id,
3956            session.user_id(),
3957            request.body.as_str(),
3958            &request.mentions,
3959            updated_at,
3960        )
3961        .await?;
3962
3963    let nonce = request.nonce.clone().context("nonce can't be blank")?;
3964
3965    let message = proto::ChannelMessage {
3966        sender_id: session.user_id().to_proto(),
3967        id: message_id.to_proto(),
3968        body: request.body.clone(),
3969        mentions: request.mentions.clone(),
3970        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3971        nonce: Some(nonce),
3972        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3973        edited_at: Some(updated_at.unix_timestamp() as u64),
3974    };
3975
3976    response.send(proto::Ack {})?;
3977
3978    let pool = &*session.connection_pool().await;
3979    broadcast(
3980        Some(session.connection_id),
3981        participant_connection_ids,
3982        |connection| {
3983            session.peer.send(
3984                connection,
3985                proto::ChannelMessageUpdate {
3986                    channel_id: channel_id.to_proto(),
3987                    message: Some(message.clone()),
3988                },
3989            )?;
3990
3991            for notification_id in &deleted_mention_notification_ids {
3992                session.peer.send(
3993                    connection,
3994                    proto::DeleteNotification {
3995                        notification_id: (*notification_id).to_proto(),
3996                    },
3997                )?;
3998            }
3999
4000            for notification in &updated_mention_notifications {
4001                session.peer.send(
4002                    connection,
4003                    proto::UpdateNotification {
4004                        notification: Some(notification.clone()),
4005                    },
4006                )?;
4007            }
4008
4009            Ok(())
4010        },
4011    );
4012
4013    send_notifications(pool, &session.peer, notifications);
4014
4015    Ok(())
4016}
4017
4018/// Mark a channel message as read
4019async fn acknowledge_channel_message(
4020    request: proto::AckChannelMessage,
4021    session: MessageContext,
4022) -> Result<()> {
4023    let channel_id = ChannelId::from_proto(request.channel_id);
4024    let message_id = MessageId::from_proto(request.message_id);
4025    let notifications = session
4026        .db()
4027        .await
4028        .observe_channel_message(channel_id, session.user_id(), message_id)
4029        .await?;
4030    send_notifications(
4031        &*session.connection_pool().await,
4032        &session.peer,
4033        notifications,
4034    );
4035    Ok(())
4036}
4037
4038/// Mark a buffer version as synced
4039async fn acknowledge_buffer_version(
4040    request: proto::AckBufferOperation,
4041    session: MessageContext,
4042) -> Result<()> {
4043    let buffer_id = BufferId::from_proto(request.buffer_id);
4044    session
4045        .db()
4046        .await
4047        .observe_buffer_version(
4048            buffer_id,
4049            session.user_id(),
4050            request.epoch as i32,
4051            &request.version,
4052        )
4053        .await?;
4054    Ok(())
4055}
4056
4057/// Get a Supermaven API key for the user
4058async fn get_supermaven_api_key(
4059    _request: proto::GetSupermavenApiKey,
4060    response: Response<proto::GetSupermavenApiKey>,
4061    session: MessageContext,
4062) -> Result<()> {
4063    let user_id: String = session.user_id().to_string();
4064    if !session.is_staff() {
4065        return Err(anyhow!("supermaven not enabled for this account"))?;
4066    }
4067
4068    let email = session.email().context("user must have an email")?;
4069
4070    let supermaven_admin_api = session
4071        .supermaven_client
4072        .as_ref()
4073        .context("supermaven not configured")?;
4074
4075    let result = supermaven_admin_api
4076        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4077        .await?;
4078
4079    response.send(proto::GetSupermavenApiKeyResponse {
4080        api_key: result.api_key,
4081    })?;
4082
4083    Ok(())
4084}
4085
4086/// Start receiving chat updates for a channel
4087async fn join_channel_chat(
4088    request: proto::JoinChannelChat,
4089    response: Response<proto::JoinChannelChat>,
4090    session: MessageContext,
4091) -> Result<()> {
4092    let channel_id = ChannelId::from_proto(request.channel_id);
4093
4094    let db = session.db().await;
4095    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4096        .await?;
4097    let messages = db
4098        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4099        .await?;
4100    response.send(proto::JoinChannelChatResponse {
4101        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4102        messages,
4103    })?;
4104    Ok(())
4105}
4106
4107/// Stop receiving chat updates for a channel
4108async fn leave_channel_chat(
4109    request: proto::LeaveChannelChat,
4110    session: MessageContext,
4111) -> Result<()> {
4112    let channel_id = ChannelId::from_proto(request.channel_id);
4113    session
4114        .db()
4115        .await
4116        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4117        .await?;
4118    Ok(())
4119}
4120
4121/// Retrieve the chat history for a channel
4122async fn get_channel_messages(
4123    request: proto::GetChannelMessages,
4124    response: Response<proto::GetChannelMessages>,
4125    session: MessageContext,
4126) -> Result<()> {
4127    let channel_id = ChannelId::from_proto(request.channel_id);
4128    let messages = session
4129        .db()
4130        .await
4131        .get_channel_messages(
4132            channel_id,
4133            session.user_id(),
4134            MESSAGE_COUNT_PER_PAGE,
4135            Some(MessageId::from_proto(request.before_message_id)),
4136        )
4137        .await?;
4138    response.send(proto::GetChannelMessagesResponse {
4139        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4140        messages,
4141    })?;
4142    Ok(())
4143}
4144
4145/// Retrieve specific chat messages
4146async fn get_channel_messages_by_id(
4147    request: proto::GetChannelMessagesById,
4148    response: Response<proto::GetChannelMessagesById>,
4149    session: MessageContext,
4150) -> Result<()> {
4151    let message_ids = request
4152        .message_ids
4153        .iter()
4154        .map(|id| MessageId::from_proto(*id))
4155        .collect::<Vec<_>>();
4156    let messages = session
4157        .db()
4158        .await
4159        .get_channel_messages_by_id(session.user_id(), &message_ids)
4160        .await?;
4161    response.send(proto::GetChannelMessagesResponse {
4162        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4163        messages,
4164    })?;
4165    Ok(())
4166}
4167
4168/// Retrieve the current users notifications
4169async fn get_notifications(
4170    request: proto::GetNotifications,
4171    response: Response<proto::GetNotifications>,
4172    session: MessageContext,
4173) -> Result<()> {
4174    let notifications = session
4175        .db()
4176        .await
4177        .get_notifications(
4178            session.user_id(),
4179            NOTIFICATION_COUNT_PER_PAGE,
4180            request.before_id.map(db::NotificationId::from_proto),
4181        )
4182        .await?;
4183    response.send(proto::GetNotificationsResponse {
4184        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4185        notifications,
4186    })?;
4187    Ok(())
4188}
4189
4190/// Mark notifications as read
4191async fn mark_notification_as_read(
4192    request: proto::MarkNotificationRead,
4193    response: Response<proto::MarkNotificationRead>,
4194    session: MessageContext,
4195) -> Result<()> {
4196    let database = &session.db().await;
4197    let notifications = database
4198        .mark_notification_as_read_by_id(
4199            session.user_id(),
4200            NotificationId::from_proto(request.notification_id),
4201        )
4202        .await?;
4203    send_notifications(
4204        &*session.connection_pool().await,
4205        &session.peer,
4206        notifications,
4207    );
4208    response.send(proto::Ack {})?;
4209    Ok(())
4210}
4211
4212/// Get the current users information
4213async fn get_private_user_info(
4214    _request: proto::GetPrivateUserInfo,
4215    response: Response<proto::GetPrivateUserInfo>,
4216    session: MessageContext,
4217) -> Result<()> {
4218    let db = session.db().await;
4219
4220    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4221    let user = db
4222        .get_user_by_id(session.user_id())
4223        .await?
4224        .context("user not found")?;
4225    let flags = db.get_user_flags(session.user_id()).await?;
4226
4227    response.send(proto::GetPrivateUserInfoResponse {
4228        metrics_id,
4229        staff: user.admin,
4230        flags,
4231        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4232    })?;
4233    Ok(())
4234}
4235
4236/// Accept the terms of service (tos) on behalf of the current user
4237async fn accept_terms_of_service(
4238    _request: proto::AcceptTermsOfService,
4239    response: Response<proto::AcceptTermsOfService>,
4240    session: MessageContext,
4241) -> Result<()> {
4242    let db = session.db().await;
4243
4244    let accepted_tos_at = Utc::now();
4245    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4246        .await?;
4247
4248    response.send(proto::AcceptTermsOfServiceResponse {
4249        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4250    })?;
4251
4252    Ok(())
4253}
4254
4255fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4256    let message = match message {
4257        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4258        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4259        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4260        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4261        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4262            code: frame.code.into(),
4263            reason: frame.reason.as_str().to_owned().into(),
4264        })),
4265        // We should never receive a frame while reading the message, according
4266        // to the `tungstenite` maintainers:
4267        //
4268        // > It cannot occur when you read messages from the WebSocket, but it
4269        // > can be used when you want to send the raw frames (e.g. you want to
4270        // > send the frames to the WebSocket without composing the full message first).
4271        // >
4272        // > — https://github.com/snapview/tungstenite-rs/issues/268
4273        TungsteniteMessage::Frame(_) => {
4274            bail!("received an unexpected frame while reading the message")
4275        }
4276    };
4277
4278    Ok(message)
4279}
4280
4281fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4282    match message {
4283        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4284        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4285        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4286        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4287        AxumMessage::Close(frame) => {
4288            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4289                code: frame.code.into(),
4290                reason: frame.reason.as_ref().into(),
4291            }))
4292        }
4293    }
4294}
4295
4296fn notify_membership_updated(
4297    connection_pool: &mut ConnectionPool,
4298    result: MembershipUpdated,
4299    user_id: UserId,
4300    peer: &Peer,
4301) {
4302    for membership in &result.new_channels.channel_memberships {
4303        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4304    }
4305    for channel_id in &result.removed_channels {
4306        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4307    }
4308
4309    let user_channels_update = proto::UpdateUserChannels {
4310        channel_memberships: result
4311            .new_channels
4312            .channel_memberships
4313            .iter()
4314            .map(|cm| proto::ChannelMembership {
4315                channel_id: cm.channel_id.to_proto(),
4316                role: cm.role.into(),
4317            })
4318            .collect(),
4319        ..Default::default()
4320    };
4321
4322    let mut update = build_channels_update(result.new_channels);
4323    update.delete_channels = result
4324        .removed_channels
4325        .into_iter()
4326        .map(|id| id.to_proto())
4327        .collect();
4328    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4329
4330    for connection_id in connection_pool.user_connection_ids(user_id) {
4331        peer.send(connection_id, user_channels_update.clone())
4332            .trace_err();
4333        peer.send(connection_id, update.clone()).trace_err();
4334    }
4335}
4336
4337fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4338    proto::UpdateUserChannels {
4339        channel_memberships: channels
4340            .channel_memberships
4341            .iter()
4342            .map(|m| proto::ChannelMembership {
4343                channel_id: m.channel_id.to_proto(),
4344                role: m.role.into(),
4345            })
4346            .collect(),
4347        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4348        observed_channel_message_id: channels.observed_channel_messages.clone(),
4349    }
4350}
4351
4352fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4353    let mut update = proto::UpdateChannels::default();
4354
4355    for channel in channels.channels {
4356        update.channels.push(channel.to_proto());
4357    }
4358
4359    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4360    update.latest_channel_message_ids = channels.latest_channel_messages;
4361
4362    for (channel_id, participants) in channels.channel_participants {
4363        update
4364            .channel_participants
4365            .push(proto::ChannelParticipants {
4366                channel_id: channel_id.to_proto(),
4367                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4368            });
4369    }
4370
4371    for channel in channels.invited_channels {
4372        update.channel_invitations.push(channel.to_proto());
4373    }
4374
4375    update
4376}
4377
4378fn build_initial_contacts_update(
4379    contacts: Vec<db::Contact>,
4380    pool: &ConnectionPool,
4381) -> proto::UpdateContacts {
4382    let mut update = proto::UpdateContacts::default();
4383
4384    for contact in contacts {
4385        match contact {
4386            db::Contact::Accepted { user_id, busy } => {
4387                update.contacts.push(contact_for_user(user_id, busy, pool));
4388            }
4389            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4390            db::Contact::Incoming { user_id } => {
4391                update
4392                    .incoming_requests
4393                    .push(proto::IncomingContactRequest {
4394                        requester_id: user_id.to_proto(),
4395                    })
4396            }
4397        }
4398    }
4399
4400    update
4401}
4402
4403fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4404    proto::Contact {
4405        user_id: user_id.to_proto(),
4406        online: pool.is_user_online(user_id),
4407        busy,
4408    }
4409}
4410
4411fn room_updated(room: &proto::Room, peer: &Peer) {
4412    broadcast(
4413        None,
4414        room.participants
4415            .iter()
4416            .filter_map(|participant| Some(participant.peer_id?.into())),
4417        |peer_id| {
4418            peer.send(
4419                peer_id,
4420                proto::RoomUpdated {
4421                    room: Some(room.clone()),
4422                },
4423            )
4424        },
4425    );
4426}
4427
4428fn channel_updated(
4429    channel: &db::channel::Model,
4430    room: &proto::Room,
4431    peer: &Peer,
4432    pool: &ConnectionPool,
4433) {
4434    let participants = room
4435        .participants
4436        .iter()
4437        .map(|p| p.user_id)
4438        .collect::<Vec<_>>();
4439
4440    broadcast(
4441        None,
4442        pool.channel_connection_ids(channel.root_id())
4443            .filter_map(|(channel_id, role)| {
4444                role.can_see_channel(channel.visibility)
4445                    .then_some(channel_id)
4446            }),
4447        |peer_id| {
4448            peer.send(
4449                peer_id,
4450                proto::UpdateChannels {
4451                    channel_participants: vec![proto::ChannelParticipants {
4452                        channel_id: channel.id.to_proto(),
4453                        participant_user_ids: participants.clone(),
4454                    }],
4455                    ..Default::default()
4456                },
4457            )
4458        },
4459    );
4460}
4461
4462async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4463    let db = session.db().await;
4464
4465    let contacts = db.get_contacts(user_id).await?;
4466    let busy = db.is_user_busy(user_id).await?;
4467
4468    let pool = session.connection_pool().await;
4469    let updated_contact = contact_for_user(user_id, busy, &pool);
4470    for contact in contacts {
4471        if let db::Contact::Accepted {
4472            user_id: contact_user_id,
4473            ..
4474        } = contact
4475        {
4476            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4477                session
4478                    .peer
4479                    .send(
4480                        contact_conn_id,
4481                        proto::UpdateContacts {
4482                            contacts: vec![updated_contact.clone()],
4483                            remove_contacts: Default::default(),
4484                            incoming_requests: Default::default(),
4485                            remove_incoming_requests: Default::default(),
4486                            outgoing_requests: Default::default(),
4487                            remove_outgoing_requests: Default::default(),
4488                        },
4489                    )
4490                    .trace_err();
4491            }
4492        }
4493    }
4494    Ok(())
4495}
4496
/// Remove `connection_id` from its room and notify everyone affected.
///
/// Steps, in order:
/// 1. Call `leave_room` on the database; if it returns `None`, there is
///    nothing to do and the function returns `Ok(())`.
/// 2. Notify collaborators of every project that was left (`project_left`)
///    and broadcast the updated room (`room_updated`).
/// 3. If the room belongs to a channel, broadcast the channel's participants.
/// 4. Send `CallCanceled` to every connection of each user whose pending call
///    was canceled.
/// 5. Refresh the contact entry of the session's user and of every user whose
///    call was canceled.
/// 6. If a LiveKit client is configured, remove the session's user from the
///    LiveKit room and delete that room when the database marked it deleted.
///
/// Database and contact-update errors are returned; peer-send and LiveKit
/// failures are passed to `trace_err` instead.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    // A set, so a user appearing more than once is only updated once.
    let mut contacts_to_update = HashSet::default();

    // These are declared up front and assigned inside the `if let` below so
    // that `left_room` itself goes out of scope at the end of that block.
    // NOTE(review): presumably `left_room` is a guard that must be released
    // before the later awaits — confirm against `Database::leave_room`.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        // Values are moved out with `mem::take`, leaving defaults behind in
        // `left_room`. The order matters: `livekit_room` is taken before
        // `room`, so the `room` value used below carries a default
        // `livekit_room`.
        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection was not in a room.
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        // The pool handle is confined to this block, so it is dropped before
        // `update_user_contacts` below asks for the connection pool itself.
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        // The LiveKit participant identity is the session user's id as a string.
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4570
4571async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4572    let left_channel_buffers = session
4573        .db()
4574        .await
4575        .leave_channel_buffers(session.connection_id)
4576        .await?;
4577
4578    for left_buffer in left_channel_buffers {
4579        channel_buffer_updated(
4580            session.connection_id,
4581            left_buffer.connections,
4582            &proto::UpdateChannelBufferCollaborators {
4583                channel_id: left_buffer.channel_id.to_proto(),
4584                collaborators: left_buffer.collaborators,
4585            },
4586            &session.peer,
4587        );
4588    }
4589
4590    Ok(())
4591}
4592
4593fn project_left(project: &db::LeftProject, session: &Session) {
4594    for connection_id in &project.connection_ids {
4595        if project.should_unshare {
4596            session
4597                .peer
4598                .send(
4599                    *connection_id,
4600                    proto::UnshareProject {
4601                        project_id: project.id.to_proto(),
4602                    },
4603                )
4604                .trace_err();
4605        } else {
4606            session
4607                .peer
4608                .send(
4609                    *connection_id,
4610                    proto::RemoveProjectCollaborator {
4611                        project_id: project.id.to_proto(),
4612                        peer_id: Some(session.connection_id.into()),
4613                    },
4614                )
4615                .trace_err();
4616        }
4617    }
4618}
4619
/// Extension trait for fallible values whose failure should be handled in
/// place rather than propagated to the caller.
pub trait ResultExt {
    /// The success value produced by [`ResultExt::trace_err`].
    type Ok;

    /// Consume `self`, yielding the success value if there is one and `None`
    /// otherwise.
    fn trace_err(self) -> Option<Self::Ok>;
}
4625
4626impl<T, E> ResultExt for Result<T, E>
4627where
4628    E: std::fmt::Debug,
4629{
4630    type Ok = T;
4631
4632    #[track_caller]
4633    fn trace_err(self) -> Option<T> {
4634        match self {
4635            Ok(value) => Some(value),
4636            Err(error) => {
4637                tracing::error!("{:?}", error);
4638                None
4639            }
4640        }
4641    }
4642}