rpc.rs

   1mod connection_pool;
   2
   3use crate::api::billing::find_or_create_billing_customer;
   4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   5use crate::db::billing_subscription::SubscriptionKind;
   6use crate::llm::db::LlmDatabase;
   7use crate::llm::{
   8    AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
   9    MIN_ACCOUNT_AGE_FOR_LLM_USE,
  10};
  11use crate::stripe_client::StripeCustomerId;
  12use crate::{
  13    AppState, Error, Result, auth,
  14    db::{
  15        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
  16        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  17        NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
  18        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  19    },
  20    executor::Executor,
  21};
  22use anyhow::{Context as _, anyhow, bail};
  23use async_tungstenite::tungstenite::{
  24    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  25};
  26use axum::headers::UserAgent;
  27use axum::{
  28    Extension, Router, TypedHeader,
  29    body::Body,
  30    extract::{
  31        ConnectInfo, WebSocketUpgrade,
  32        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  33    },
  34    headers::{Header, HeaderName},
  35    http::StatusCode,
  36    middleware,
  37    response::IntoResponse,
  38    routing::get,
  39};
  40use chrono::Utc;
  41use collections::{HashMap, HashSet};
  42pub use connection_pool::{ConnectionPool, ZedVersion};
  43use core::fmt::{self, Debug, Formatter};
  44use reqwest_client::ReqwestClient;
  45use rpc::proto::{MultiLspQuery, split_repository_update};
  46use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  47
  48use futures::{
  49    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  50    stream::FuturesUnordered,
  51};
  52use prometheus::{IntGauge, register_int_gauge};
  53use rpc::{
  54    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  55    proto::{
  56        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  57        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  58    },
  59};
  60use semantic_version::SemanticVersion;
  61use serde::{Serialize, Serializer};
  62use std::{
  63    any::TypeId,
  64    future::Future,
  65    marker::PhantomData,
  66    mem,
  67    net::SocketAddr,
  68    ops::{Deref, DerefMut},
  69    rc::Rc,
  70    sync::{
  71        Arc, OnceLock,
  72        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  73    },
  74    time::{Duration, Instant},
  75};
  76use time::OffsetDateTime;
  77use tokio::sync::{Semaphore, watch};
  78use tower::ServiceBuilder;
  79use tracing::{
  80    Instrument,
  81    field::{self},
  82    info_span, instrument,
  83};
  84
/// Timeout related to client reconnection.
/// NOTE(review): meaning inferred from the name only — confirm at the call sites,
/// which are outside this excerpt.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long `Server::start` sleeps before cleaning up stale rooms, channel buffers,
/// channel-chat participants, worktree entries and servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we
/// can clean up old resources.
/// NOTE(review): Kubernetes' *default* grace period is 30s, so the 10s figure presumably
/// comes from this deployment's pod configuration — confirm if that config changes.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Paging/size limits. NOTE(review): semantics inferred from the names; their uses are
// outside this excerpt.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Upper bound on simultaneously held [`ConnectionGuard`]s.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Number of currently held [`ConnectionGuard`]s (incremented in `try_acquire`,
/// decremented on drop).
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

/// Type-erased message handler stored in `Server::handlers`, built by
/// `Server::add_handler`: takes the received envelope and the connection's session and
/// returns a boxed future that performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  99
 100pub struct ConnectionGuard;
 101
 102impl ConnectionGuard {
 103    pub fn try_acquire() -> Result<Self, ()> {
 104        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
 105        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 106            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 107            tracing::error!(
 108                "too many concurrent connections: {}",
 109                current_connections + 1
 110            );
 111            return Err(());
 112        }
 113        Ok(ConnectionGuard)
 114    }
 115}
 116
 117impl Drop for ConnectionGuard {
 118    fn drop(&mut self) {
 119        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 120    }
 121}
 122
 123struct Response<R> {
 124    peer: Arc<Peer>,
 125    receipt: Receipt<R>,
 126    responded: Arc<AtomicBool>,
 127}
 128
 129impl<R: RequestMessage> Response<R> {
 130    fn send(self, payload: R::Response) -> Result<()> {
 131        self.responded.store(true, SeqCst);
 132        self.peer.respond(self.receipt, payload)?;
 133        Ok(())
 134    }
 135}
 136
 137#[derive(Clone, Debug)]
 138pub enum Principal {
 139    User(User),
 140    Impersonated { user: User, admin: User },
 141}
 142
 143impl Principal {
 144    fn user(&self) -> &User {
 145        match self {
 146            Principal::User(user) => user,
 147            Principal::Impersonated { user, .. } => user,
 148        }
 149    }
 150
 151    fn update_span(&self, span: &tracing::Span) {
 152        match &self {
 153            Principal::User(user) => {
 154                span.record("user_id", user.id.0);
 155                span.record("login", &user.github_login);
 156            }
 157            Principal::Impersonated { user, admin } => {
 158                span.record("user_id", user.id.0);
 159                span.record("login", &user.github_login);
 160                span.record("impersonator", &admin.github_login);
 161            }
 162        }
 163    }
 164}
 165
 166#[derive(Clone)]
 167struct Session {
 168    principal: Principal,
 169    connection_id: ConnectionId,
 170    db: Arc<tokio::sync::Mutex<DbHandle>>,
 171    peer: Arc<Peer>,
 172    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 173    app_state: Arc<AppState>,
 174    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 175    /// The GeoIP country code for the user.
 176    #[allow(unused)]
 177    geoip_country_code: Option<String>,
 178    system_id: Option<String>,
 179    _executor: Executor,
 180}
 181
 182impl Session {
 183    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 184        #[cfg(test)]
 185        tokio::task::yield_now().await;
 186        let guard = self.db.lock().await;
 187        #[cfg(test)]
 188        tokio::task::yield_now().await;
 189        guard
 190    }
 191
 192    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 193        #[cfg(test)]
 194        tokio::task::yield_now().await;
 195        let guard = self.connection_pool.lock();
 196        ConnectionPoolGuard {
 197            guard,
 198            _not_send: PhantomData,
 199        }
 200    }
 201
 202    fn is_staff(&self) -> bool {
 203        match &self.principal {
 204            Principal::User(user) => user.admin,
 205            Principal::Impersonated { .. } => true,
 206        }
 207    }
 208
 209    fn user_id(&self) -> UserId {
 210        match &self.principal {
 211            Principal::User(user) => user.id,
 212            Principal::Impersonated { user, .. } => user.id,
 213        }
 214    }
 215
 216    pub fn email(&self) -> Option<String> {
 217        match &self.principal {
 218            Principal::User(user) => user.email_address.clone(),
 219            Principal::Impersonated { user, .. } => user.email_address.clone(),
 220        }
 221    }
 222}
 223
 224impl Debug for Session {
 225    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 226        let mut result = f.debug_struct("Session");
 227        match &self.principal {
 228            Principal::User(user) => {
 229                result.field("user", &user.github_login);
 230            }
 231            Principal::Impersonated { user, admin } => {
 232                result.field("user", &user.github_login);
 233                result.field("impersonator", &admin.github_login);
 234            }
 235        }
 236        result.field("connection_id", &self.connection_id).finish()
 237    }
 238}
 239
 240struct DbHandle(Arc<Database>);
 241
 242impl Deref for DbHandle {
 243    type Target = Database;
 244
 245    fn deref(&self) -> &Self::Target {
 246        self.0.as_ref()
 247    }
 248}
 249
/// The RPC server: owns the peer, the pool of live connections and the table of
/// registered message handlers.
pub struct Server {
    // Behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // Type-erased handlers keyed by the `TypeId` of the message type they accept;
    // filled by `add_handler`, which panics on duplicates.
    handlers: HashMap<TypeId, MessageHandler>,
    // Set to `true` by `teardown` and back to `false` by `reset`.
    teardown: watch::Sender<bool>,
}
 258
/// Lock guard over the shared [`ConnectionPool`], returned by `Session::connection_pool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc<()>` is `!Send`, so this marker makes the whole guard `!Send`: it cannot be
    // held across an `.await` inside a future that must be `Send`.
    _not_send: PhantomData<Rc<()>>,
}
 263
/// Serializable view of the server's peer and connection pool.
///
/// The pool is held through its lock guard and serialized via [`serialize_deref`], so
/// the connection-pool lock stays held for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 270
 271pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 272where
 273    S: Serializer,
 274    T: Deref<Target = U>,
 275    U: Serialize,
 276{
 277    Serialize::serialize(value.deref(), serializer)
 278}
 279
 280impl Server {
 281    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 282        let mut server = Self {
 283            id: parking_lot::Mutex::new(id),
 284            peer: Peer::new(id.0 as u32),
 285            app_state: app_state.clone(),
 286            connection_pool: Default::default(),
 287            handlers: Default::default(),
 288            teardown: watch::channel(false).0,
 289        };
 290
 291        server
 292            .add_request_handler(ping)
 293            .add_request_handler(create_room)
 294            .add_request_handler(join_room)
 295            .add_request_handler(rejoin_room)
 296            .add_request_handler(leave_room)
 297            .add_request_handler(set_room_participant_role)
 298            .add_request_handler(call)
 299            .add_request_handler(cancel_call)
 300            .add_message_handler(decline_call)
 301            .add_request_handler(update_participant_location)
 302            .add_request_handler(share_project)
 303            .add_message_handler(unshare_project)
 304            .add_request_handler(join_project)
 305            .add_message_handler(leave_project)
 306            .add_request_handler(update_project)
 307            .add_request_handler(update_worktree)
 308            .add_request_handler(update_repository)
 309            .add_request_handler(remove_repository)
 310            .add_message_handler(start_language_server)
 311            .add_message_handler(update_language_server)
 312            .add_message_handler(update_diagnostic_summary)
 313            .add_message_handler(update_worktree_settings)
 314            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 315            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 316            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 317            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 318            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 319            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 320            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 321            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 322            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 323            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 324            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 325            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 326            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 327            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 328            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 329            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 330            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 331            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 332            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 333            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 334            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 335            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 336            .add_request_handler(
 337                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 338            )
 339            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 340            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 341            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 342            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 343            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
 344            .add_request_handler(
 345                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 346            )
 347            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 348            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 349            .add_request_handler(
 350                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 351            )
 352            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 353            .add_request_handler(
 354                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 355            )
 356            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 357            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 358            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 359            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 360            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 361            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 362            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 363            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 364            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 365            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 366            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 367            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 368            .add_request_handler(
 369                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 370            )
 371            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 372            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 373            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 374            .add_request_handler(multi_lsp_query)
 375            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 376            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 377            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 378            .add_message_handler(create_buffer_for_peer)
 379            .add_request_handler(update_buffer)
 380            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 381            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 382            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 383            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 384            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 385            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 386            .add_message_handler(
 387                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 388            )
 389            .add_request_handler(get_users)
 390            .add_request_handler(fuzzy_search_users)
 391            .add_request_handler(request_contact)
 392            .add_request_handler(remove_contact)
 393            .add_request_handler(respond_to_contact_request)
 394            .add_message_handler(subscribe_to_channels)
 395            .add_request_handler(create_channel)
 396            .add_request_handler(delete_channel)
 397            .add_request_handler(invite_channel_member)
 398            .add_request_handler(remove_channel_member)
 399            .add_request_handler(set_channel_member_role)
 400            .add_request_handler(set_channel_visibility)
 401            .add_request_handler(rename_channel)
 402            .add_request_handler(join_channel_buffer)
 403            .add_request_handler(leave_channel_buffer)
 404            .add_message_handler(update_channel_buffer)
 405            .add_request_handler(rejoin_channel_buffers)
 406            .add_request_handler(get_channel_members)
 407            .add_request_handler(respond_to_channel_invite)
 408            .add_request_handler(join_channel)
 409            .add_request_handler(join_channel_chat)
 410            .add_message_handler(leave_channel_chat)
 411            .add_request_handler(send_channel_message)
 412            .add_request_handler(remove_channel_message)
 413            .add_request_handler(update_channel_message)
 414            .add_request_handler(get_channel_messages)
 415            .add_request_handler(get_channel_messages_by_id)
 416            .add_request_handler(get_notifications)
 417            .add_request_handler(mark_notification_as_read)
 418            .add_request_handler(move_channel)
 419            .add_request_handler(reorder_channel)
 420            .add_request_handler(follow)
 421            .add_message_handler(unfollow)
 422            .add_message_handler(update_followers)
 423            .add_request_handler(get_private_user_info)
 424            .add_request_handler(get_llm_api_token)
 425            .add_request_handler(accept_terms_of_service)
 426            .add_message_handler(acknowledge_channel_message)
 427            .add_message_handler(acknowledge_buffer_version)
 428            .add_request_handler(get_supermaven_api_key)
 429            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 430            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 431            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 432            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 433            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 434            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 435            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 436            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 437            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 438            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 439            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 440            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 441            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 442            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 443            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 444            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 445            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 446            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 447            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 448            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 449            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 450            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 451            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 452            .add_message_handler(update_context);
 453
 454        Arc::new(server)
 455    }
 456
    /// Schedules cleanup of stale server resources on a detached task and returns.
    ///
    /// The task sleeps for [`CLEANUP_TIMEOUT`] and then performs these steps in order:
    /// 1. Deletes stale channel-chat participants.
    /// 2. Fetches stale room and channel-buffer ids. It clears stale channel-buffer
    ///    collaborators and notifies the remaining ones.
    /// 3. Clears stale participants from each room, broadcasts the refreshed room and
    ///    channel, notifies users whose calls were canceled, and pushes refreshed contact
    ///    entries. It also deletes the LiveKit room once the room has no participants.
    /// 4. Deletes stale channel-chat participants again, clears old worktree entries, and
    ///    deletes stale servers.
    ///
    /// "Stale" is decided by the database for `zed_environment` relative to this
    /// server's id. Presumably this means resources owned by earlier server instances;
    /// confirm in `db`.
    ///
    /// Each fallible step goes through `trace_err`, which turns an error into `None`
    /// (and presumably logs it). One failing step does not stop the others.
    ///
    /// Always returns `Ok(())`; failures inside the detached task are never reported to
    /// the caller.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and tell the
                    // connections still in it about the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled from the refreshed room below; they stay empty/false if
                        // clearing the room's participants failed.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and users whose calls were
                            // canceled need refreshed contact entries.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the `.await`s below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to every
                        // connection of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only tear down the media room once nobody is left in it.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // NOTE(review): this repeats the call made at the start of the task.
                // Confirm whether the second pass is intentional (e.g. to catch rows
                // written in the meantime) or can be removed.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 627
 628    pub fn teardown(&self) {
 629        self.peer.teardown();
 630        self.connection_pool.lock().reset();
 631        let _ = self.teardown.send(true);
 632    }
 633
 634    #[cfg(test)]
 635    pub fn reset(&self, id: ServerId) {
 636        self.teardown();
 637        *self.id.lock() = id;
 638        self.peer.reset(id.0 as u32);
 639        let _ = self.teardown.send(false);
 640    }
 641
 642    #[cfg(test)]
 643    pub fn id(&self) -> ServerId {
 644        *self.id.lock()
 645    }
 646
    /// Registers the type-erased handler for message type `M`.
    ///
    /// The stored closure downcasts the incoming envelope to `TypedEnvelope<M>`, invokes
    /// `handler`, and wraps the returned future so that, once it completes, the outcome is
    /// logged (errors at `error` level, success at `info` level) together with three timings
    /// in milliseconds:
    /// - `total_duration_ms`: elapsed since the envelope's `received_at`;
    /// - `processing_duration_ms`: elapsed since this wrapper invoked `handler` (so it also
    ///   includes any time the returned future waited before being polled);
    /// - `queue_duration_ms`: the difference, i.e. time before `handler` was invoked.
    ///
    /// Handler errors are only logged here; the wrapped future itself does not return them.
    ///
    /// # Panics
    /// Panics if a handler for `M` was already registered. The stored closure panics if it is
    /// given an envelope that is not a `TypedEnvelope<M>`.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session| {
                // Handlers are keyed by `TypeId::of::<M>()`, so this downcast is expected to succeed.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                // Taken before `handler` runs; everything before this point counts as queue time.
                let start_time = Instant::now();
                let future = (handler)(*envelope, session);
                async move {
                    let result = future.await;
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;

                    match result {
                        Err(error) => {
                            tracing::error!(
                                ?error,
                                total_duration_ms,
                                processing_duration_ms,
                                queue_duration_ms,
                                "error handling message"
                            )
                        }
                        Ok(()) => tracing::info!(
                            total_duration_ms,
                            processing_duration_ms,
                            queue_duration_ms,
                            "finished handling message"
                        ),
                    }
                }
                .boxed()
            }),
        );
        // Each message type may have exactly one handler.
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
 693
 694    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 695    where
 696        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 697        Fut: 'static + Send + Future<Output = Result<()>>,
 698        M: EnvelopedMessage,
 699    {
 700        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 701        self
 702    }
 703
 704    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 705    where
 706        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 707        Fut: Send + Future<Output = Result<()>>,
 708        M: RequestMessage,
 709    {
 710        let handler = Arc::new(handler);
 711        self.add_handler(move |envelope, session| {
 712            let receipt = envelope.receipt();
 713            let handler = handler.clone();
 714            async move {
 715                let peer = session.peer.clone();
 716                let responded = Arc::new(AtomicBool::default());
 717                let response = Response {
 718                    peer: peer.clone(),
 719                    responded: responded.clone(),
 720                    receipt,
 721                };
 722                match (handler)(envelope.payload, response, session).await {
 723                    Ok(()) => {
 724                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 725                            Ok(())
 726                        } else {
 727                            Err(anyhow!("handler did not send a response"))?
 728                        }
 729                    }
 730                    Err(error) => {
 731                        let proto_err = match &error {
 732                            Error::Internal(err) => err.to_proto(),
 733                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 734                        };
 735                        peer.respond_with_error(receipt, proto_err)?;
 736                        Err(error)
 737                    }
 738                }
 739            }
 740        })
 741    }
 742
    /// Builds the future that serves one client connection from handshake to sign-out.
    ///
    /// Nothing runs until the returned future is polled. When polled it:
    /// 1. returns immediately if the teardown watch already reads `true`;
    /// 2. registers `connection` with the peer and records the new `ConnectionId` on the span;
    /// 3. builds the per-connection [`Session`], including a Supermaven admin client when
    ///    `supermaven_admin_api_key` is configured;
    /// 4. runs `send_initial_client_update` (which sends `Hello` and, when provided, reports
    ///    the id through `send_connection_id`), then drops `connection_guard`;
    /// 5. dispatches incoming messages to the registered handlers until the I/O future
    ///    finishes, the incoming stream ends, or the teardown watch changes;
    /// 6. after I/O or stream end (not after a teardown change), calls `connection_lost`.
    ///
    /// `use<>` means the returned future captures no lifetimes; it owns a clone of `self`.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Per-connection span. Fields start empty and are filled in as values become known:
        // some immediately below, `connection_id` once the peer assigns it.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribe now so that the receiver is owned by the returned future.
        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse the connection if teardown has already been signalled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // This client only backs the optional Supermaven admin API below. Failing to build
            // it aborts the whole connection.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                // NOTE(review): this early return skips `connection_lost`. The connection
                // registered via `peer.add_connection` above (and possibly already added to
                // the connection pool by `send_initial_client_update`) is therefore not cleaned
                // up on this path. Confirm that this is intended.
                return;
            }
            // The initial update has been sent; release `connection_guard`.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            //
            // A semaphore permit is acquired before each message is read. It is released only
            // when that message's handler completes, or right away if there is no handler. So at
            // most MAX_CONCURRENT_HANDLERS messages (background ones included) are in flight per
            // connection, and reading pauses while all permits are taken.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            // Number of permits currently held, i.e. handlers in flight.
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in order: teardown, I/O completion, finishing
                // foreground handlers, then reading the next message.
                futures::select_biased! {
                    // Teardown: exit without running `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                multi_lsp_query_request=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span before storing the future; the future re-enters
                                // it through `.instrument(span)`.
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached. Foreground ones are driven by
                                // this loop via `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any foreground handlers still pending before signing out.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
 929
 930    async fn send_initial_client_update(
 931        &self,
 932        connection_id: ConnectionId,
 933        zed_version: ZedVersion,
 934        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 935        session: &Session,
 936    ) -> Result<()> {
 937        self.peer.send(
 938            connection_id,
 939            proto::Hello {
 940                peer_id: Some(connection_id.into()),
 941            },
 942        )?;
 943        tracing::info!("sent hello message");
 944        if let Some(send_connection_id) = send_connection_id.take() {
 945            let _ = send_connection_id.send(connection_id);
 946        }
 947
 948        match &session.principal {
 949            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 950                if !user.connected_once {
 951                    self.peer.send(connection_id, proto::ShowContacts {})?;
 952                    self.app_state
 953                        .db
 954                        .set_user_connected_once(user.id, true)
 955                        .await?;
 956                }
 957
 958                update_user_plan(session).await?;
 959
 960                let contacts = self.app_state.db.get_contacts(user.id).await?;
 961
 962                {
 963                    let mut pool = self.connection_pool.lock();
 964                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 965                    self.peer.send(
 966                        connection_id,
 967                        build_initial_contacts_update(contacts, &pool),
 968                    )?;
 969                }
 970
 971                if should_auto_subscribe_to_channels(zed_version) {
 972                    subscribe_user_to_channels(user.id, session).await?;
 973                }
 974
 975                if let Some(incoming_call) =
 976                    self.app_state.db.incoming_call_for_user(user.id).await?
 977                {
 978                    self.peer.send(connection_id, incoming_call)?;
 979                }
 980
 981                update_user_contacts(user.id, session).await?;
 982            }
 983        }
 984
 985        Ok(())
 986    }
 987
 988    pub async fn invite_code_redeemed(
 989        self: &Arc<Self>,
 990        inviter_id: UserId,
 991        invitee_id: UserId,
 992    ) -> Result<()> {
 993        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 994            if let Some(code) = &user.invite_code {
 995                let pool = self.connection_pool.lock();
 996                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 997                for connection_id in pool.user_connection_ids(inviter_id) {
 998                    self.peer.send(
 999                        connection_id,
1000                        proto::UpdateContacts {
1001                            contacts: vec![invitee_contact.clone()],
1002                            ..Default::default()
1003                        },
1004                    )?;
1005                    self.peer.send(
1006                        connection_id,
1007                        proto::UpdateInviteInfo {
1008                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1009                            count: user.invite_count as u32,
1010                        },
1011                    )?;
1012                }
1013            }
1014        }
1015        Ok(())
1016    }
1017
1018    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1019        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1020            if let Some(invite_code) = &user.invite_code {
1021                let pool = self.connection_pool.lock();
1022                for connection_id in pool.user_connection_ids(user_id) {
1023                    self.peer.send(
1024                        connection_id,
1025                        proto::UpdateInviteInfo {
1026                            url: format!(
1027                                "{}{}",
1028                                self.app_state.config.invite_link_prefix, invite_code
1029                            ),
1030                            count: user.invite_count as u32,
1031                        },
1032                    )?;
1033                }
1034            }
1035        }
1036        Ok(())
1037    }
1038
1039    pub async fn update_plan_for_user(
1040        self: &Arc<Self>,
1041        user_id: UserId,
1042        update_user_plan: proto::UpdateUserPlan,
1043    ) -> Result<()> {
1044        let pool = self.connection_pool.lock();
1045        for connection_id in pool.user_connection_ids(user_id) {
1046            self.peer
1047                .send(connection_id, update_user_plan.clone())
1048                .trace_err();
1049        }
1050
1051        Ok(())
1052    }
1053
1054    /// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan`
1055    /// message on the Collab server.
1056    ///
1057    /// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint.
1058    pub async fn update_plan_for_user_legacy(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1059        let user = self
1060            .app_state
1061            .db
1062            .get_user_by_id(user_id)
1063            .await?
1064            .context("user not found")?;
1065
1066        let update_user_plan = make_update_user_plan_message(
1067            &user,
1068            user.admin,
1069            &self.app_state.db,
1070            self.app_state.llm_db.clone(),
1071        )
1072        .await?;
1073
1074        self.update_plan_for_user(user_id, update_user_plan).await
1075    }
1076
1077    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
1078        let pool = self.connection_pool.lock();
1079        for connection_id in pool.user_connection_ids(user_id) {
1080            self.peer
1081                .send(connection_id, proto::RefreshLlmToken {})
1082                .trace_err();
1083        }
1084    }
1085
1086    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1087        ServerSnapshot {
1088            connection_pool: ConnectionPoolGuard {
1089                guard: self.connection_pool.lock(),
1090                _not_send: PhantomData,
1091            },
1092            peer: &self.peer,
1093        }
1094    }
1095}
1096
1097impl Deref for ConnectionPoolGuard<'_> {
1098    type Target = ConnectionPool;
1099
1100    fn deref(&self) -> &Self::Target {
1101        &self.guard
1102    }
1103}
1104
1105impl DerefMut for ConnectionPoolGuard<'_> {
1106    fn deref_mut(&mut self) -> &mut Self::Target {
1107        &mut self.guard
1108    }
1109}
1110
1111impl Drop for ConnectionPoolGuard<'_> {
1112    fn drop(&mut self) {
1113        #[cfg(test)]
1114        self.check_invariants();
1115    }
1116}
1117
1118fn broadcast<F>(
1119    sender_id: Option<ConnectionId>,
1120    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1121    mut f: F,
1122) where
1123    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1124{
1125    for receiver_id in receiver_ids {
1126        if Some(receiver_id) != sender_id {
1127            if let Err(error) = f(receiver_id) {
1128                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1129            }
1130        }
1131    }
1132}
1133
1134pub struct ProtocolVersion(u32);
1135
1136impl Header for ProtocolVersion {
1137    fn name() -> &'static HeaderName {
1138        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1139        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1140    }
1141
1142    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1143    where
1144        Self: Sized,
1145        I: Iterator<Item = &'i axum::http::HeaderValue>,
1146    {
1147        let version = values
1148            .next()
1149            .ok_or_else(axum::headers::Error::invalid)?
1150            .to_str()
1151            .map_err(|_| axum::headers::Error::invalid())?
1152            .parse()
1153            .map_err(|_| axum::headers::Error::invalid())?;
1154        Ok(Self(version))
1155    }
1156
1157    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1158        values.extend([self.0.to_string().parse().unwrap()]);
1159    }
1160}
1161
1162pub struct AppVersionHeader(SemanticVersion);
1163impl Header for AppVersionHeader {
1164    fn name() -> &'static HeaderName {
1165        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1166        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1167    }
1168
1169    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1170    where
1171        Self: Sized,
1172        I: Iterator<Item = &'i axum::http::HeaderValue>,
1173    {
1174        let version = values
1175            .next()
1176            .ok_or_else(axum::headers::Error::invalid)?
1177            .to_str()
1178            .map_err(|_| axum::headers::Error::invalid())?
1179            .parse()
1180            .map_err(|_| axum::headers::Error::invalid())?;
1181        Ok(Self(version))
1182    }
1183
1184    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1185        values.extend([self.0.to_string().parse().unwrap()]);
1186    }
1187}
1188
1189#[derive(Debug)]
1190pub struct ReleaseChannelHeader(String);
1191
1192impl Header for ReleaseChannelHeader {
1193    fn name() -> &'static HeaderName {
1194        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1195        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1196    }
1197
1198    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1199    where
1200        Self: Sized,
1201        I: Iterator<Item = &'i axum::http::HeaderValue>,
1202    {
1203        Ok(Self(
1204            values
1205                .next()
1206                .ok_or_else(axum::headers::Error::invalid)?
1207                .to_str()
1208                .map_err(|_| axum::headers::Error::invalid())?
1209                .to_owned(),
1210        ))
1211    }
1212
1213    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1214        values.extend([self.0.parse().unwrap()]);
1215    }
1216}
1217
1218pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1219    Router::new()
1220        .route("/rpc", get(handle_websocket_request))
1221        .layer(
1222            ServiceBuilder::new()
1223                .layer(Extension(server.app_state.clone()))
1224                .layer(middleware::from_fn(auth::validate_header)),
1225        )
1226        .route("/metrics", get(handle_metrics))
1227        .layer(Extension(server))
1228}
1229
1230pub async fn handle_websocket_request(
1231    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1232    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1233    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1234    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1235    Extension(server): Extension<Arc<Server>>,
1236    Extension(principal): Extension<Principal>,
1237    user_agent: Option<TypedHeader<UserAgent>>,
1238    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1239    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1240    ws: WebSocketUpgrade,
1241) -> axum::response::Response {
1242    if protocol_version != rpc::PROTOCOL_VERSION {
1243        return (
1244            StatusCode::UPGRADE_REQUIRED,
1245            "client must be upgraded".to_string(),
1246        )
1247            .into_response();
1248    }
1249
1250    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1251        return (
1252            StatusCode::UPGRADE_REQUIRED,
1253            "no version header found".to_string(),
1254        )
1255            .into_response();
1256    };
1257
1258    let release_channel = release_channel_header.map(|header| header.0.0);
1259
1260    if !version.can_collaborate() {
1261        return (
1262            StatusCode::UPGRADE_REQUIRED,
1263            "client must be upgraded".to_string(),
1264        )
1265            .into_response();
1266    }
1267
1268    let socket_address = socket_address.to_string();
1269
1270    // Acquire connection guard before WebSocket upgrade
1271    let connection_guard = match ConnectionGuard::try_acquire() {
1272        Ok(guard) => guard,
1273        Err(()) => {
1274            return (
1275                StatusCode::SERVICE_UNAVAILABLE,
1276                "Too many concurrent connections",
1277            )
1278                .into_response();
1279        }
1280    };
1281
1282    ws.on_upgrade(move |socket| {
1283        let socket = socket
1284            .map_ok(to_tungstenite_message)
1285            .err_into()
1286            .with(|message| async move { to_axum_message(message) });
1287        let connection = Connection::new(Box::pin(socket));
1288        async move {
1289            server
1290                .handle_connection(
1291                    connection,
1292                    socket_address,
1293                    principal,
1294                    version,
1295                    release_channel,
1296                    user_agent.map(|header| header.to_string()),
1297                    country_code_header.map(|header| header.to_string()),
1298                    system_id_header.map(|header| header.to_string()),
1299                    None,
1300                    Executor::Production,
1301                    Some(connection_guard),
1302                )
1303                .await;
1304        }
1305    })
1306}
1307
1308pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1309    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1310    let connections_metric = CONNECTIONS_METRIC
1311        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1312
1313    let connections = server
1314        .connection_pool
1315        .lock()
1316        .connections()
1317        .filter(|connection| !connection.admin)
1318        .count();
1319    connections_metric.set(connections as _);
1320
1321    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1322    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1323        register_int_gauge!(
1324            "shared_projects",
1325            "number of open projects with one or more guests"
1326        )
1327        .unwrap()
1328    });
1329
1330    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1331    shared_projects_metric.set(shared_projects as _);
1332
1333    let encoder = prometheus::TextEncoder::new();
1334    let metric_families = prometheus::gather();
1335    let encoded_metrics = encoder
1336        .encode_to_string(&metric_families)
1337        .map_err(|err| anyhow!("{err}"))?;
1338    Ok(encoded_metrics)
1339}
1340
1341#[instrument(err, skip(executor))]
1342async fn connection_lost(
1343    session: Session,
1344    mut teardown: watch::Receiver<bool>,
1345    executor: Executor,
1346) -> Result<()> {
1347    session.peer.disconnect(session.connection_id);
1348    session
1349        .connection_pool()
1350        .await
1351        .remove_connection(session.connection_id)?;
1352
1353    session
1354        .db()
1355        .await
1356        .connection_lost(session.connection_id)
1357        .await
1358        .trace_err();
1359
1360    futures::select_biased! {
1361        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1362
1363            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1364            leave_room_for_session(&session, session.connection_id).await.trace_err();
1365            leave_channel_buffers_for_session(&session)
1366                .await
1367                .trace_err();
1368
1369            if !session
1370                .connection_pool()
1371                .await
1372                .is_user_online(session.user_id())
1373            {
1374                let db = session.db().await;
1375                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1376                    room_updated(&room, &session.peer);
1377                }
1378            }
1379
1380            update_user_contacts(session.user_id(), &session).await?;
1381        },
1382        _ = teardown.changed().fuse() => {}
1383    }
1384
1385    Ok(())
1386}
1387
1388/// Acknowledges a ping from a client, used to keep the connection alive.
1389async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1390    response.send(proto::Ack {})?;
1391    Ok(())
1392}
1393
1394/// Creates a new room for calling (outside of channels)
1395async fn create_room(
1396    _request: proto::CreateRoom,
1397    response: Response<proto::CreateRoom>,
1398    session: Session,
1399) -> Result<()> {
1400    let livekit_room = nanoid::nanoid!(30);
1401
1402    let live_kit_connection_info = util::maybe!(async {
1403        let live_kit = session.app_state.livekit_client.as_ref();
1404        let live_kit = live_kit?;
1405        let user_id = session.user_id().to_string();
1406
1407        let token = live_kit
1408            .room_token(&livekit_room, &user_id.to_string())
1409            .trace_err()?;
1410
1411        Some(proto::LiveKitConnectionInfo {
1412            server_url: live_kit.url().into(),
1413            token,
1414            can_publish: true,
1415        })
1416    })
1417    .await;
1418
1419    let room = session
1420        .db()
1421        .await
1422        .create_room(session.user_id(), session.connection_id, &livekit_room)
1423        .await?;
1424
1425    response.send(proto::CreateRoomResponse {
1426        room: Some(room.clone()),
1427        live_kit_connection_info,
1428    })?;
1429
1430    update_user_contacts(session.user_id(), &session).await?;
1431    Ok(())
1432}
1433
1434/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1435async fn join_room(
1436    request: proto::JoinRoom,
1437    response: Response<proto::JoinRoom>,
1438    session: Session,
1439) -> Result<()> {
1440    let room_id = RoomId::from_proto(request.id);
1441
1442    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1443
1444    if let Some(channel_id) = channel_id {
1445        return join_channel_internal(channel_id, Box::new(response), session).await;
1446    }
1447
1448    let joined_room = {
1449        let room = session
1450            .db()
1451            .await
1452            .join_room(room_id, session.user_id(), session.connection_id)
1453            .await?;
1454        room_updated(&room.room, &session.peer);
1455        room.into_inner()
1456    };
1457
1458    for connection_id in session
1459        .connection_pool()
1460        .await
1461        .user_connection_ids(session.user_id())
1462    {
1463        session
1464            .peer
1465            .send(
1466                connection_id,
1467                proto::CallCanceled {
1468                    room_id: room_id.to_proto(),
1469                },
1470            )
1471            .trace_err();
1472    }
1473
1474    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1475    {
1476        live_kit
1477            .room_token(
1478                &joined_room.room.livekit_room,
1479                &session.user_id().to_string(),
1480            )
1481            .trace_err()
1482            .map(|token| proto::LiveKitConnectionInfo {
1483                server_url: live_kit.url().into(),
1484                token,
1485                can_publish: true,
1486            })
1487    } else {
1488        None
1489    };
1490
1491    response.send(proto::JoinRoomResponse {
1492        room: Some(joined_room.room),
1493        channel_id: None,
1494        live_kit_connection_info,
1495    })?;
1496
1497    update_user_contacts(session.user_id(), &session).await?;
1498    Ok(())
1499}
1500
/// Reconnects the caller to a room after a connection error.
///
/// The database call re-associates the caller's new connection with the room
/// using the state carried in `request`. The caller receives a
/// `RejoinRoomResponse` with the room, the reshared projects (presumably ones
/// the caller hosts — TODO confirm), and the projects it had rejoined. Other
/// participants are told about the caller's new peer id, and the room's channel
/// (if any) and the caller's contacts are then refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Assigned inside the scope below so that the database result is dropped
    // before the connection pool is accessed afterwards.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Reply to the rejoining client first, before notifying anyone else.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that the caller's peer id moved from the old
            // connection to the current one. Send failures are only logged.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the reshared project's worktree list to its collaborators.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Stream the rejoined projects' state to the caller. This takes the
        // worktrees and repositories out of `rejoined_projects`.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // If the room belongs to a channel, broadcast the channel's updated state.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1592
1593fn notify_rejoined_projects(
1594    rejoined_projects: &mut Vec<RejoinedProject>,
1595    session: &Session,
1596) -> Result<()> {
1597    for project in rejoined_projects.iter() {
1598        for collaborator in &project.collaborators {
1599            session
1600                .peer
1601                .send(
1602                    collaborator.connection_id,
1603                    proto::UpdateProjectCollaborator {
1604                        project_id: project.id.to_proto(),
1605                        old_peer_id: Some(project.old_connection_id.into()),
1606                        new_peer_id: Some(session.connection_id.into()),
1607                    },
1608                )
1609                .trace_err();
1610        }
1611    }
1612
1613    for project in rejoined_projects {
1614        for worktree in mem::take(&mut project.worktrees) {
1615            // Stream this worktree's entries.
1616            let message = proto::UpdateWorktree {
1617                project_id: project.id.to_proto(),
1618                worktree_id: worktree.id,
1619                abs_path: worktree.abs_path.clone(),
1620                root_name: worktree.root_name,
1621                updated_entries: worktree.updated_entries,
1622                removed_entries: worktree.removed_entries,
1623                scan_id: worktree.scan_id,
1624                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1625                updated_repositories: worktree.updated_repositories,
1626                removed_repositories: worktree.removed_repositories,
1627            };
1628            for update in proto::split_worktree_update(message) {
1629                session.peer.send(session.connection_id, update)?;
1630            }
1631
1632            // Stream this worktree's diagnostics.
1633            for summary in worktree.diagnostic_summaries {
1634                session.peer.send(
1635                    session.connection_id,
1636                    proto::UpdateDiagnosticSummary {
1637                        project_id: project.id.to_proto(),
1638                        worktree_id: worktree.id,
1639                        summary: Some(summary),
1640                    },
1641                )?;
1642            }
1643
1644            for settings_file in worktree.settings_files {
1645                session.peer.send(
1646                    session.connection_id,
1647                    proto::UpdateWorktreeSettings {
1648                        project_id: project.id.to_proto(),
1649                        worktree_id: worktree.id,
1650                        path: settings_file.path,
1651                        content: Some(settings_file.content),
1652                        kind: Some(settings_file.kind.to_proto().into()),
1653                    },
1654                )?;
1655            }
1656        }
1657
1658        for repository in mem::take(&mut project.updated_repositories) {
1659            for update in split_repository_update(repository) {
1660                session.peer.send(session.connection_id, update)?;
1661            }
1662        }
1663
1664        for id in mem::take(&mut project.removed_repositories) {
1665            session.peer.send(
1666                session.connection_id,
1667                proto::RemoveRepository {
1668                    project_id: project.id.to_proto(),
1669                    id,
1670                },
1671            )?;
1672        }
1673    }
1674
1675    Ok(())
1676}
1677
1678/// leave room disconnects from the room.
1679async fn leave_room(
1680    _: proto::LeaveRoom,
1681    response: Response<proto::LeaveRoom>,
1682    session: Session,
1683) -> Result<()> {
1684    leave_room_for_session(&session, session.connection_id).await?;
1685    response.send(proto::Ack {})?;
1686    Ok(())
1687}
1688
1689/// Updates the permissions of someone else in the room.
1690async fn set_room_participant_role(
1691    request: proto::SetRoomParticipantRole,
1692    response: Response<proto::SetRoomParticipantRole>,
1693    session: Session,
1694) -> Result<()> {
1695    let user_id = UserId::from_proto(request.user_id);
1696    let role = ChannelRole::from(request.role());
1697
1698    let (livekit_room, can_publish) = {
1699        let room = session
1700            .db()
1701            .await
1702            .set_room_participant_role(
1703                session.user_id(),
1704                RoomId::from_proto(request.room_id),
1705                user_id,
1706                role,
1707            )
1708            .await?;
1709
1710        let livekit_room = room.livekit_room.clone();
1711        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1712        room_updated(&room, &session.peer);
1713        (livekit_room, can_publish)
1714    };
1715
1716    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1717        live_kit
1718            .update_participant(
1719                livekit_room.clone(),
1720                request.user_id.to_string(),
1721                livekit_api::proto::ParticipantPermission {
1722                    can_subscribe: true,
1723                    can_publish,
1724                    can_publish_data: can_publish,
1725                    hidden: false,
1726                    recorder: false,
1727                },
1728            )
1729            .await
1730            .trace_err();
1731    }
1732
1733    response.send(proto::Ack {})?;
1734    Ok(())
1735}
1736
/// Calls another user into the current room.
///
/// The called user must be one of the caller's contacts. The call is recorded
/// in the database, and every connection of the called user is sent the
/// incoming-call request concurrently. The first successful response
/// acknowledges this request. If no connection responds successfully (including
/// when the user has no connections), the call is marked as failed and an error
/// is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Scoped so the database result is dropped before contacts are updated.
    // The incoming-call message is moved out with `mem::take`.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the called user at once. The value returned by
    // `connection_pool()` is a temporary of this statement, so it is dropped
    // before the requests are awaited.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // The first successful response wins. Failed responses are logged and skipped.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // Nobody answered successfully: record the failure and refresh the room.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1805
1806/// Cancel an outgoing call.
1807async fn cancel_call(
1808    request: proto::CancelCall,
1809    response: Response<proto::CancelCall>,
1810    session: Session,
1811) -> Result<()> {
1812    let called_user_id = UserId::from_proto(request.called_user_id);
1813    let room_id = RoomId::from_proto(request.room_id);
1814    {
1815        let room = session
1816            .db()
1817            .await
1818            .cancel_call(room_id, session.connection_id, called_user_id)
1819            .await?;
1820        room_updated(&room, &session.peer);
1821    }
1822
1823    for connection_id in session
1824        .connection_pool()
1825        .await
1826        .user_connection_ids(called_user_id)
1827    {
1828        session
1829            .peer
1830            .send(
1831                connection_id,
1832                proto::CallCanceled {
1833                    room_id: room_id.to_proto(),
1834                },
1835            )
1836            .trace_err();
1837    }
1838    response.send(proto::Ack {})?;
1839
1840    update_user_contacts(called_user_id, &session).await?;
1841    Ok(())
1842}
1843
1844/// Decline an incoming call.
1845async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1846    let room_id = RoomId::from_proto(message.room_id);
1847    {
1848        let room = session
1849            .db()
1850            .await
1851            .decline_call(Some(room_id), session.user_id())
1852            .await?
1853            .context("declining call")?;
1854        room_updated(&room, &session.peer);
1855    }
1856
1857    for connection_id in session
1858        .connection_pool()
1859        .await
1860        .user_connection_ids(session.user_id())
1861    {
1862        session
1863            .peer
1864            .send(
1865                connection_id,
1866                proto::CallCanceled {
1867                    room_id: room_id.to_proto(),
1868                },
1869            )
1870            .trace_err();
1871    }
1872    update_user_contacts(session.user_id(), &session).await?;
1873    Ok(())
1874}
1875
1876/// Updates other participants in the room with your current location.
1877async fn update_participant_location(
1878    request: proto::UpdateParticipantLocation,
1879    response: Response<proto::UpdateParticipantLocation>,
1880    session: Session,
1881) -> Result<()> {
1882    let room_id = RoomId::from_proto(request.room_id);
1883    let location = request.location.context("invalid location")?;
1884
1885    let db = session.db().await;
1886    let room = db
1887        .update_room_participant_location(room_id, session.connection_id, location)
1888        .await?;
1889
1890    room_updated(&room, &session.peer);
1891    response.send(proto::Ack {})?;
1892    Ok(())
1893}
1894
1895/// Share a project into the room.
1896async fn share_project(
1897    request: proto::ShareProject,
1898    response: Response<proto::ShareProject>,
1899    session: Session,
1900) -> Result<()> {
1901    let (project_id, room) = &*session
1902        .db()
1903        .await
1904        .share_project(
1905            RoomId::from_proto(request.room_id),
1906            session.connection_id,
1907            &request.worktrees,
1908            request.is_ssh_project,
1909        )
1910        .await?;
1911    response.send(proto::ShareProjectResponse {
1912        project_id: project_id.to_proto(),
1913    })?;
1914    room_updated(room, &session.peer);
1915
1916    Ok(())
1917}
1918
1919/// Unshare a project from the room.
1920async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1921    let project_id = ProjectId::from_proto(message.project_id);
1922    unshare_project_internal(project_id, session.connection_id, &session).await
1923}
1924
1925async fn unshare_project_internal(
1926    project_id: ProjectId,
1927    connection_id: ConnectionId,
1928    session: &Session,
1929) -> Result<()> {
1930    let delete = {
1931        let room_guard = session
1932            .db()
1933            .await
1934            .unshare_project(project_id, connection_id)
1935            .await?;
1936
1937        let (delete, room, guest_connection_ids) = &*room_guard;
1938
1939        let message = proto::UnshareProject {
1940            project_id: project_id.to_proto(),
1941        };
1942
1943        broadcast(
1944            Some(connection_id),
1945            guest_connection_ids.iter().copied(),
1946            |conn_id| session.peer.send(conn_id, message.clone()),
1947        );
1948        if let Some(room) = room {
1949            room_updated(room, &session.peer);
1950        }
1951
1952        *delete
1953    };
1954
1955    if delete {
1956        let db = session.db().await;
1957        db.delete_project(project_id).await?;
1958    }
1959
1960    Ok(())
1961}
1962
1963/// Join someone elses shared project.
1964async fn join_project(
1965    request: proto::JoinProject,
1966    response: Response<proto::JoinProject>,
1967    session: Session,
1968) -> Result<()> {
1969    let project_id = ProjectId::from_proto(request.project_id);
1970
1971    tracing::info!(%project_id, "join project");
1972
1973    let db = session.db().await;
1974    let (project, replica_id) = &mut *db
1975        .join_project(
1976            project_id,
1977            session.connection_id,
1978            session.user_id(),
1979            request.committer_name.clone(),
1980            request.committer_email.clone(),
1981        )
1982        .await?;
1983    drop(db);
1984    tracing::info!(%project_id, "join remote project");
1985    let collaborators = project
1986        .collaborators
1987        .iter()
1988        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1989        .map(|collaborator| collaborator.to_proto())
1990        .collect::<Vec<_>>();
1991    let project_id = project.id;
1992    let guest_user_id = session.user_id();
1993
1994    let worktrees = project
1995        .worktrees
1996        .iter()
1997        .map(|(id, worktree)| proto::WorktreeMetadata {
1998            id: *id,
1999            root_name: worktree.root_name.clone(),
2000            visible: worktree.visible,
2001            abs_path: worktree.abs_path.clone(),
2002        })
2003        .collect::<Vec<_>>();
2004
2005    let add_project_collaborator = proto::AddProjectCollaborator {
2006        project_id: project_id.to_proto(),
2007        collaborator: Some(proto::Collaborator {
2008            peer_id: Some(session.connection_id.into()),
2009            replica_id: replica_id.0 as u32,
2010            user_id: guest_user_id.to_proto(),
2011            is_host: false,
2012            committer_name: request.committer_name.clone(),
2013            committer_email: request.committer_email.clone(),
2014        }),
2015    };
2016
2017    for collaborator in &collaborators {
2018        session
2019            .peer
2020            .send(
2021                collaborator.peer_id.unwrap().into(),
2022                add_project_collaborator.clone(),
2023            )
2024            .trace_err();
2025    }
2026
2027    // First, we send the metadata associated with each worktree.
2028    let (language_servers, language_server_capabilities) = project
2029        .language_servers
2030        .clone()
2031        .into_iter()
2032        .map(|server| (server.server, server.capabilities))
2033        .unzip();
2034    response.send(proto::JoinProjectResponse {
2035        project_id: project.id.0 as u64,
2036        worktrees: worktrees.clone(),
2037        replica_id: replica_id.0 as u32,
2038        collaborators: collaborators.clone(),
2039        language_servers,
2040        language_server_capabilities,
2041        role: project.role.into(),
2042    })?;
2043
2044    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2045        // Stream this worktree's entries.
2046        let message = proto::UpdateWorktree {
2047            project_id: project_id.to_proto(),
2048            worktree_id,
2049            abs_path: worktree.abs_path.clone(),
2050            root_name: worktree.root_name,
2051            updated_entries: worktree.entries,
2052            removed_entries: Default::default(),
2053            scan_id: worktree.scan_id,
2054            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2055            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2056            removed_repositories: Default::default(),
2057        };
2058        for update in proto::split_worktree_update(message) {
2059            session.peer.send(session.connection_id, update.clone())?;
2060        }
2061
2062        // Stream this worktree's diagnostics.
2063        for summary in worktree.diagnostic_summaries {
2064            session.peer.send(
2065                session.connection_id,
2066                proto::UpdateDiagnosticSummary {
2067                    project_id: project_id.to_proto(),
2068                    worktree_id: worktree.id,
2069                    summary: Some(summary),
2070                },
2071            )?;
2072        }
2073
2074        for settings_file in worktree.settings_files {
2075            session.peer.send(
2076                session.connection_id,
2077                proto::UpdateWorktreeSettings {
2078                    project_id: project_id.to_proto(),
2079                    worktree_id: worktree.id,
2080                    path: settings_file.path,
2081                    content: Some(settings_file.content),
2082                    kind: Some(settings_file.kind.to_proto() as i32),
2083                },
2084            )?;
2085        }
2086    }
2087
2088    for repository in mem::take(&mut project.repositories) {
2089        for update in split_repository_update(repository) {
2090            session.peer.send(session.connection_id, update)?;
2091        }
2092    }
2093
2094    for language_server in &project.language_servers {
2095        session.peer.send(
2096            session.connection_id,
2097            proto::UpdateLanguageServer {
2098                project_id: project_id.to_proto(),
2099                server_name: Some(language_server.server.name.clone()),
2100                language_server_id: language_server.server.id,
2101                variant: Some(
2102                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2103                        proto::LspDiskBasedDiagnosticsUpdated {},
2104                    ),
2105                ),
2106            },
2107        )?;
2108    }
2109
2110    Ok(())
2111}
2112
2113/// Leave someone elses shared project.
2114async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
2115    let sender_id = session.connection_id;
2116    let project_id = ProjectId::from_proto(request.project_id);
2117    let db = session.db().await;
2118
2119    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2120    tracing::info!(
2121        %project_id,
2122        "leave project"
2123    );
2124
2125    project_left(project, &session);
2126    if let Some(room) = room {
2127        room_updated(room, &session.peer);
2128    }
2129
2130    Ok(())
2131}
2132
2133/// Updates other participants with changes to the project
2134async fn update_project(
2135    request: proto::UpdateProject,
2136    response: Response<proto::UpdateProject>,
2137    session: Session,
2138) -> Result<()> {
2139    let project_id = ProjectId::from_proto(request.project_id);
2140    let (room, guest_connection_ids) = &*session
2141        .db()
2142        .await
2143        .update_project(project_id, session.connection_id, &request.worktrees)
2144        .await?;
2145    broadcast(
2146        Some(session.connection_id),
2147        guest_connection_ids.iter().copied(),
2148        |connection_id| {
2149            session
2150                .peer
2151                .forward_send(session.connection_id, connection_id, request.clone())
2152        },
2153    );
2154    if let Some(room) = room {
2155        room_updated(room, &session.peer);
2156    }
2157    response.send(proto::Ack {})?;
2158
2159    Ok(())
2160}
2161
2162/// Updates other participants with changes to the worktree
2163async fn update_worktree(
2164    request: proto::UpdateWorktree,
2165    response: Response<proto::UpdateWorktree>,
2166    session: Session,
2167) -> Result<()> {
2168    let guest_connection_ids = session
2169        .db()
2170        .await
2171        .update_worktree(&request, session.connection_id)
2172        .await?;
2173
2174    broadcast(
2175        Some(session.connection_id),
2176        guest_connection_ids.iter().copied(),
2177        |connection_id| {
2178            session
2179                .peer
2180                .forward_send(session.connection_id, connection_id, request.clone())
2181        },
2182    );
2183    response.send(proto::Ack {})?;
2184    Ok(())
2185}
2186
2187async fn update_repository(
2188    request: proto::UpdateRepository,
2189    response: Response<proto::UpdateRepository>,
2190    session: Session,
2191) -> Result<()> {
2192    let guest_connection_ids = session
2193        .db()
2194        .await
2195        .update_repository(&request, session.connection_id)
2196        .await?;
2197
2198    broadcast(
2199        Some(session.connection_id),
2200        guest_connection_ids.iter().copied(),
2201        |connection_id| {
2202            session
2203                .peer
2204                .forward_send(session.connection_id, connection_id, request.clone())
2205        },
2206    );
2207    response.send(proto::Ack {})?;
2208    Ok(())
2209}
2210
2211async fn remove_repository(
2212    request: proto::RemoveRepository,
2213    response: Response<proto::RemoveRepository>,
2214    session: Session,
2215) -> Result<()> {
2216    let guest_connection_ids = session
2217        .db()
2218        .await
2219        .remove_repository(&request, session.connection_id)
2220        .await?;
2221
2222    broadcast(
2223        Some(session.connection_id),
2224        guest_connection_ids.iter().copied(),
2225        |connection_id| {
2226            session
2227                .peer
2228                .forward_send(session.connection_id, connection_id, request.clone())
2229        },
2230    );
2231    response.send(proto::Ack {})?;
2232    Ok(())
2233}
2234
2235/// Updates other participants with changes to the diagnostics
2236async fn update_diagnostic_summary(
2237    message: proto::UpdateDiagnosticSummary,
2238    session: Session,
2239) -> Result<()> {
2240    let guest_connection_ids = session
2241        .db()
2242        .await
2243        .update_diagnostic_summary(&message, session.connection_id)
2244        .await?;
2245
2246    broadcast(
2247        Some(session.connection_id),
2248        guest_connection_ids.iter().copied(),
2249        |connection_id| {
2250            session
2251                .peer
2252                .forward_send(session.connection_id, connection_id, message.clone())
2253        },
2254    );
2255
2256    Ok(())
2257}
2258
2259/// Updates other participants with changes to the worktree settings
2260async fn update_worktree_settings(
2261    message: proto::UpdateWorktreeSettings,
2262    session: Session,
2263) -> Result<()> {
2264    let guest_connection_ids = session
2265        .db()
2266        .await
2267        .update_worktree_settings(&message, session.connection_id)
2268        .await?;
2269
2270    broadcast(
2271        Some(session.connection_id),
2272        guest_connection_ids.iter().copied(),
2273        |connection_id| {
2274            session
2275                .peer
2276                .forward_send(session.connection_id, connection_id, message.clone())
2277        },
2278    );
2279
2280    Ok(())
2281}
2282
2283/// Notify other participants that a language server has started.
2284async fn start_language_server(
2285    request: proto::StartLanguageServer,
2286    session: Session,
2287) -> Result<()> {
2288    let guest_connection_ids = session
2289        .db()
2290        .await
2291        .start_language_server(&request, session.connection_id)
2292        .await?;
2293
2294    broadcast(
2295        Some(session.connection_id),
2296        guest_connection_ids.iter().copied(),
2297        |connection_id| {
2298            session
2299                .peer
2300                .forward_send(session.connection_id, connection_id, request.clone())
2301        },
2302    );
2303    Ok(())
2304}
2305
2306/// Notify other participants that a language server has changed.
2307async fn update_language_server(
2308    request: proto::UpdateLanguageServer,
2309    session: Session,
2310) -> Result<()> {
2311    let project_id = ProjectId::from_proto(request.project_id);
2312    let db = session.db().await;
2313
2314    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2315    {
2316        if let Some(capabilities) = update.capabilities.clone() {
2317            db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2318                .await?;
2319        }
2320    }
2321
2322    let project_connection_ids = db
2323        .project_connection_ids(project_id, session.connection_id, true)
2324        .await?;
2325    broadcast(
2326        Some(session.connection_id),
2327        project_connection_ids.iter().copied(),
2328        |connection_id| {
2329            session
2330                .peer
2331                .forward_send(session.connection_id, connection_id, request.clone())
2332        },
2333    );
2334    Ok(())
2335}
2336
2337/// forward a project request to the host. These requests should be read only
2338/// as guests are allowed to send them.
2339async fn forward_read_only_project_request<T>(
2340    request: T,
2341    response: Response<T>,
2342    session: Session,
2343) -> Result<()>
2344where
2345    T: EntityMessage + RequestMessage,
2346{
2347    let project_id = ProjectId::from_proto(request.remote_entity_id());
2348    let host_connection_id = session
2349        .db()
2350        .await
2351        .host_for_read_only_project_request(project_id, session.connection_id)
2352        .await?;
2353    let payload = session
2354        .peer
2355        .forward_request(session.connection_id, host_connection_id, request)
2356        .await?;
2357    response.send(payload)?;
2358    Ok(())
2359}
2360
2361/// forward a project request to the host. These requests are disallowed
2362/// for guests.
2363async fn forward_mutating_project_request<T>(
2364    request: T,
2365    response: Response<T>,
2366    session: Session,
2367) -> Result<()>
2368where
2369    T: EntityMessage + RequestMessage,
2370{
2371    let project_id = ProjectId::from_proto(request.remote_entity_id());
2372
2373    let host_connection_id = session
2374        .db()
2375        .await
2376        .host_for_mutating_project_request(project_id, session.connection_id)
2377        .await?;
2378    let payload = session
2379        .peer
2380        .forward_request(session.connection_id, host_connection_id, request)
2381        .await?;
2382    response.send(payload)?;
2383    Ok(())
2384}
2385
2386async fn multi_lsp_query(
2387    request: MultiLspQuery,
2388    response: Response<MultiLspQuery>,
2389    session: Session,
2390) -> Result<()> {
2391    tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2392    tracing::info!("multi_lsp_query message received");
2393    forward_mutating_project_request(request, response, session).await
2394}
2395
2396/// Notify other participants that a new buffer has been created
2397async fn create_buffer_for_peer(
2398    request: proto::CreateBufferForPeer,
2399    session: Session,
2400) -> Result<()> {
2401    session
2402        .db()
2403        .await
2404        .check_user_is_project_host(
2405            ProjectId::from_proto(request.project_id),
2406            session.connection_id,
2407        )
2408        .await?;
2409    let peer_id = request.peer_id.context("invalid peer id")?;
2410    session
2411        .peer
2412        .forward_send(session.connection_id, peer_id.into(), request)?;
2413    Ok(())
2414}
2415
2416/// Notify other participants that a buffer has been updated. This is
2417/// allowed for guests as long as the update is limited to selections.
2418async fn update_buffer(
2419    request: proto::UpdateBuffer,
2420    response: Response<proto::UpdateBuffer>,
2421    session: Session,
2422) -> Result<()> {
2423    let project_id = ProjectId::from_proto(request.project_id);
2424    let mut capability = Capability::ReadOnly;
2425
2426    for op in request.operations.iter() {
2427        match op.variant {
2428            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2429            Some(_) => capability = Capability::ReadWrite,
2430        }
2431    }
2432
2433    let host = {
2434        let guard = session
2435            .db()
2436            .await
2437            .connections_for_buffer_update(project_id, session.connection_id, capability)
2438            .await?;
2439
2440        let (host, guests) = &*guard;
2441
2442        broadcast(
2443            Some(session.connection_id),
2444            guests.clone(),
2445            |connection_id| {
2446                session
2447                    .peer
2448                    .forward_send(session.connection_id, connection_id, request.clone())
2449            },
2450        );
2451
2452        *host
2453    };
2454
2455    if host != session.connection_id {
2456        session
2457            .peer
2458            .forward_request(session.connection_id, host, request.clone())
2459            .await?;
2460    }
2461
2462    response.send(proto::Ack {})?;
2463    Ok(())
2464}
2465
2466async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2467    let project_id = ProjectId::from_proto(message.project_id);
2468
2469    let operation = message.operation.as_ref().context("invalid operation")?;
2470    let capability = match operation.variant.as_ref() {
2471        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2472            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2473                match buffer_op.variant {
2474                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2475                        Capability::ReadOnly
2476                    }
2477                    _ => Capability::ReadWrite,
2478                }
2479            } else {
2480                Capability::ReadWrite
2481            }
2482        }
2483        Some(_) => Capability::ReadWrite,
2484        None => Capability::ReadOnly,
2485    };
2486
2487    let guard = session
2488        .db()
2489        .await
2490        .connections_for_buffer_update(project_id, session.connection_id, capability)
2491        .await?;
2492
2493    let (host, guests) = &*guard;
2494
2495    broadcast(
2496        Some(session.connection_id),
2497        guests.iter().chain([host]).copied(),
2498        |connection_id| {
2499            session
2500                .peer
2501                .forward_send(session.connection_id, connection_id, message.clone())
2502        },
2503    );
2504
2505    Ok(())
2506}
2507
2508/// Notify other participants that a project has been updated.
2509async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2510    request: T,
2511    session: Session,
2512) -> Result<()> {
2513    let project_id = ProjectId::from_proto(request.remote_entity_id());
2514    let project_connection_ids = session
2515        .db()
2516        .await
2517        .project_connection_ids(project_id, session.connection_id, false)
2518        .await?;
2519
2520    broadcast(
2521        Some(session.connection_id),
2522        project_connection_ids.iter().copied(),
2523        |connection_id| {
2524            session
2525                .peer
2526                .forward_send(session.connection_id, connection_id, request.clone())
2527        },
2528    );
2529    Ok(())
2530}
2531
2532/// Start following another user in a call.
2533async fn follow(
2534    request: proto::Follow,
2535    response: Response<proto::Follow>,
2536    session: Session,
2537) -> Result<()> {
2538    let room_id = RoomId::from_proto(request.room_id);
2539    let project_id = request.project_id.map(ProjectId::from_proto);
2540    let leader_id = request.leader_id.context("invalid leader id")?.into();
2541    let follower_id = session.connection_id;
2542
2543    session
2544        .db()
2545        .await
2546        .check_room_participants(room_id, leader_id, session.connection_id)
2547        .await?;
2548
2549    let response_payload = session
2550        .peer
2551        .forward_request(session.connection_id, leader_id, request)
2552        .await?;
2553    response.send(response_payload)?;
2554
2555    if let Some(project_id) = project_id {
2556        let room = session
2557            .db()
2558            .await
2559            .follow(room_id, project_id, leader_id, follower_id)
2560            .await?;
2561        room_updated(&room, &session.peer);
2562    }
2563
2564    Ok(())
2565}
2566
2567/// Stop following another user in a call.
2568async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2569    let room_id = RoomId::from_proto(request.room_id);
2570    let project_id = request.project_id.map(ProjectId::from_proto);
2571    let leader_id = request.leader_id.context("invalid leader id")?.into();
2572    let follower_id = session.connection_id;
2573
2574    session
2575        .db()
2576        .await
2577        .check_room_participants(room_id, leader_id, session.connection_id)
2578        .await?;
2579
2580    session
2581        .peer
2582        .forward_send(session.connection_id, leader_id, request)?;
2583
2584    if let Some(project_id) = project_id {
2585        let room = session
2586            .db()
2587            .await
2588            .unfollow(room_id, project_id, leader_id, follower_id)
2589            .await?;
2590        room_updated(&room, &session.peer);
2591    }
2592
2593    Ok(())
2594}
2595
2596/// Notify everyone following you of your current location.
2597async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2598    let room_id = RoomId::from_proto(request.room_id);
2599    let database = session.db.lock().await;
2600
2601    let connection_ids = if let Some(project_id) = request.project_id {
2602        let project_id = ProjectId::from_proto(project_id);
2603        database
2604            .project_connection_ids(project_id, session.connection_id, true)
2605            .await?
2606    } else {
2607        database
2608            .room_connection_ids(room_id, session.connection_id)
2609            .await?
2610    };
2611
2612    // For now, don't send view update messages back to that view's current leader.
2613    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2614        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2615        _ => None,
2616    });
2617
2618    for connection_id in connection_ids.iter().cloned() {
2619        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2620            session
2621                .peer
2622                .forward_send(session.connection_id, connection_id, request.clone())?;
2623        }
2624    }
2625    Ok(())
2626}
2627
2628/// Get public data about users.
2629async fn get_users(
2630    request: proto::GetUsers,
2631    response: Response<proto::GetUsers>,
2632    session: Session,
2633) -> Result<()> {
2634    let user_ids = request
2635        .user_ids
2636        .into_iter()
2637        .map(UserId::from_proto)
2638        .collect();
2639    let users = session
2640        .db()
2641        .await
2642        .get_users_by_ids(user_ids)
2643        .await?
2644        .into_iter()
2645        .map(|user| proto::User {
2646            id: user.id.to_proto(),
2647            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2648            github_login: user.github_login,
2649            name: user.name,
2650        })
2651        .collect();
2652    response.send(proto::UsersResponse { users })?;
2653    Ok(())
2654}
2655
2656/// Search for users (to invite) buy Github login
2657async fn fuzzy_search_users(
2658    request: proto::FuzzySearchUsers,
2659    response: Response<proto::FuzzySearchUsers>,
2660    session: Session,
2661) -> Result<()> {
2662    let query = request.query;
2663    let users = match query.len() {
2664        0 => vec![],
2665        1 | 2 => session
2666            .db()
2667            .await
2668            .get_user_by_github_login(&query)
2669            .await?
2670            .into_iter()
2671            .collect(),
2672        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2673    };
2674    let users = users
2675        .into_iter()
2676        .filter(|user| user.id != session.user_id())
2677        .map(|user| proto::User {
2678            id: user.id.to_proto(),
2679            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2680            github_login: user.github_login,
2681            name: user.name,
2682        })
2683        .collect();
2684    response.send(proto::UsersResponse { users })?;
2685    Ok(())
2686}
2687
2688/// Send a contact request to another user.
2689async fn request_contact(
2690    request: proto::RequestContact,
2691    response: Response<proto::RequestContact>,
2692    session: Session,
2693) -> Result<()> {
2694    let requester_id = session.user_id();
2695    let responder_id = UserId::from_proto(request.responder_id);
2696    if requester_id == responder_id {
2697        return Err(anyhow!("cannot add yourself as a contact"))?;
2698    }
2699
2700    let notifications = session
2701        .db()
2702        .await
2703        .send_contact_request(requester_id, responder_id)
2704        .await?;
2705
2706    // Update outgoing contact requests of requester
2707    let mut update = proto::UpdateContacts::default();
2708    update.outgoing_requests.push(responder_id.to_proto());
2709    for connection_id in session
2710        .connection_pool()
2711        .await
2712        .user_connection_ids(requester_id)
2713    {
2714        session.peer.send(connection_id, update.clone())?;
2715    }
2716
2717    // Update incoming contact requests of responder
2718    let mut update = proto::UpdateContacts::default();
2719    update
2720        .incoming_requests
2721        .push(proto::IncomingContactRequest {
2722            requester_id: requester_id.to_proto(),
2723        });
2724    let connection_pool = session.connection_pool().await;
2725    for connection_id in connection_pool.user_connection_ids(responder_id) {
2726        session.peer.send(connection_id, update.clone())?;
2727    }
2728
2729    send_notifications(&connection_pool, &session.peer, notifications);
2730
2731    response.send(proto::Ack {})?;
2732    Ok(())
2733}
2734
2735/// Accept or decline a contact request
2736async fn respond_to_contact_request(
2737    request: proto::RespondToContactRequest,
2738    response: Response<proto::RespondToContactRequest>,
2739    session: Session,
2740) -> Result<()> {
2741    let responder_id = session.user_id();
2742    let requester_id = UserId::from_proto(request.requester_id);
2743    let db = session.db().await;
2744    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2745        db.dismiss_contact_notification(responder_id, requester_id)
2746            .await?;
2747    } else {
2748        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2749
2750        let notifications = db
2751            .respond_to_contact_request(responder_id, requester_id, accept)
2752            .await?;
2753        let requester_busy = db.is_user_busy(requester_id).await?;
2754        let responder_busy = db.is_user_busy(responder_id).await?;
2755
2756        let pool = session.connection_pool().await;
2757        // Update responder with new contact
2758        let mut update = proto::UpdateContacts::default();
2759        if accept {
2760            update
2761                .contacts
2762                .push(contact_for_user(requester_id, requester_busy, &pool));
2763        }
2764        update
2765            .remove_incoming_requests
2766            .push(requester_id.to_proto());
2767        for connection_id in pool.user_connection_ids(responder_id) {
2768            session.peer.send(connection_id, update.clone())?;
2769        }
2770
2771        // Update requester with new contact
2772        let mut update = proto::UpdateContacts::default();
2773        if accept {
2774            update
2775                .contacts
2776                .push(contact_for_user(responder_id, responder_busy, &pool));
2777        }
2778        update
2779            .remove_outgoing_requests
2780            .push(responder_id.to_proto());
2781
2782        for connection_id in pool.user_connection_ids(requester_id) {
2783            session.peer.send(connection_id, update.clone())?;
2784        }
2785
2786        send_notifications(&pool, &session.peer, notifications);
2787    }
2788
2789    response.send(proto::Ack {})?;
2790    Ok(())
2791}
2792
2793/// Remove a contact.
2794async fn remove_contact(
2795    request: proto::RemoveContact,
2796    response: Response<proto::RemoveContact>,
2797    session: Session,
2798) -> Result<()> {
2799    let requester_id = session.user_id();
2800    let responder_id = UserId::from_proto(request.user_id);
2801    let db = session.db().await;
2802    let (contact_accepted, deleted_notification_id) =
2803        db.remove_contact(requester_id, responder_id).await?;
2804
2805    let pool = session.connection_pool().await;
2806    // Update outgoing contact requests of requester
2807    let mut update = proto::UpdateContacts::default();
2808    if contact_accepted {
2809        update.remove_contacts.push(responder_id.to_proto());
2810    } else {
2811        update
2812            .remove_outgoing_requests
2813            .push(responder_id.to_proto());
2814    }
2815    for connection_id in pool.user_connection_ids(requester_id) {
2816        session.peer.send(connection_id, update.clone())?;
2817    }
2818
2819    // Update incoming contact requests of responder
2820    let mut update = proto::UpdateContacts::default();
2821    if contact_accepted {
2822        update.remove_contacts.push(requester_id.to_proto());
2823    } else {
2824        update
2825            .remove_incoming_requests
2826            .push(requester_id.to_proto());
2827    }
2828    for connection_id in pool.user_connection_ids(responder_id) {
2829        session.peer.send(connection_id, update.clone())?;
2830        if let Some(notification_id) = deleted_notification_id {
2831            session.peer.send(
2832                connection_id,
2833                proto::DeleteNotification {
2834                    notification_id: notification_id.to_proto(),
2835                },
2836            )?;
2837        }
2838    }
2839
2840    response.send(proto::Ack {})?;
2841    Ok(())
2842}
2843
2844fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2845    version.0.minor() < 139
2846}
2847
2848async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2849    if is_staff {
2850        return Ok(proto::Plan::ZedPro);
2851    }
2852
2853    let subscription = db.get_active_billing_subscription(user_id).await?;
2854    let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2855
2856    let plan = if let Some(subscription_kind) = subscription_kind {
2857        match subscription_kind {
2858            SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2859            SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2860            SubscriptionKind::ZedFree => proto::Plan::Free,
2861        }
2862    } else {
2863        proto::Plan::Free
2864    };
2865
2866    Ok(plan)
2867}
2868
2869async fn make_update_user_plan_message(
2870    user: &User,
2871    is_staff: bool,
2872    db: &Arc<Database>,
2873    llm_db: Option<Arc<LlmDatabase>>,
2874) -> Result<proto::UpdateUserPlan> {
2875    let feature_flags = db.get_user_flags(user.id).await?;
2876    let plan = current_plan(db, user.id, is_staff).await?;
2877    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
2878    let billing_preferences = db.get_billing_preferences(user.id).await?;
2879
2880    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
2881        let subscription = db.get_active_billing_subscription(user.id).await?;
2882
2883        let subscription_period =
2884            crate::db::billing_subscription::Model::current_period(subscription, is_staff);
2885
2886        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
2887            llm_db
2888                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
2889                .await?
2890        } else {
2891            None
2892        };
2893
2894        (subscription_period, usage)
2895    } else {
2896        (None, None)
2897    };
2898
2899    let bypass_account_age_check = feature_flags
2900        .iter()
2901        .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
2902    let account_too_young = !matches!(plan, proto::Plan::ZedPro)
2903        && !bypass_account_age_check
2904        && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;
2905
2906    Ok(proto::UpdateUserPlan {
2907        plan: plan.into(),
2908        trial_started_at: billing_customer
2909            .as_ref()
2910            .and_then(|billing_customer| billing_customer.trial_started_at)
2911            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2912        is_usage_based_billing_enabled: if is_staff {
2913            Some(true)
2914        } else {
2915            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
2916        },
2917        subscription_period: subscription_period.map(|(started_at, ended_at)| {
2918            proto::SubscriptionPeriod {
2919                started_at: started_at.timestamp() as u64,
2920                ended_at: ended_at.timestamp() as u64,
2921            }
2922        }),
2923        account_too_young: Some(account_too_young),
2924        has_overdue_invoices: billing_customer
2925            .map(|billing_customer| billing_customer.has_overdue_invoices),
2926        usage: Some(
2927            usage
2928                .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
2929                .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
2930        ),
2931    })
2932}
2933
2934fn model_requests_limit(
2935    plan: cloud_llm_client::Plan,
2936    feature_flags: &Vec<String>,
2937) -> cloud_llm_client::UsageLimit {
2938    match plan.model_requests_limit() {
2939        cloud_llm_client::UsageLimit::Limited(limit) => {
2940            let limit = if plan == cloud_llm_client::Plan::ZedProTrial
2941                && feature_flags
2942                    .iter()
2943                    .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2944            {
2945                1_000
2946            } else {
2947                limit
2948            };
2949
2950            cloud_llm_client::UsageLimit::Limited(limit)
2951        }
2952        cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited,
2953    }
2954}
2955
2956fn subscription_usage_to_proto(
2957    plan: proto::Plan,
2958    usage: crate::llm::db::subscription_usage::Model,
2959    feature_flags: &Vec<String>,
2960) -> proto::SubscriptionUsage {
2961    let plan = match plan {
2962        proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2963        proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
2964        proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
2965    };
2966
2967    proto::SubscriptionUsage {
2968        model_requests_usage_amount: usage.model_requests as u32,
2969        model_requests_usage_limit: Some(proto::UsageLimit {
2970            variant: Some(match model_requests_limit(plan, feature_flags) {
2971                cloud_llm_client::UsageLimit::Limited(limit) => {
2972                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2973                        limit: limit as u32,
2974                    })
2975                }
2976                cloud_llm_client::UsageLimit::Unlimited => {
2977                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2978                }
2979            }),
2980        }),
2981        edit_predictions_usage_amount: usage.edit_predictions as u32,
2982        edit_predictions_usage_limit: Some(proto::UsageLimit {
2983            variant: Some(match plan.edit_predictions_limit() {
2984                cloud_llm_client::UsageLimit::Limited(limit) => {
2985                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2986                        limit: limit as u32,
2987                    })
2988                }
2989                cloud_llm_client::UsageLimit::Unlimited => {
2990                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2991                }
2992            }),
2993        }),
2994    }
2995}
2996
2997fn make_default_subscription_usage(
2998    plan: proto::Plan,
2999    feature_flags: &Vec<String>,
3000) -> proto::SubscriptionUsage {
3001    let plan = match plan {
3002        proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
3003        proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
3004        proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
3005    };
3006
3007    proto::SubscriptionUsage {
3008        model_requests_usage_amount: 0,
3009        model_requests_usage_limit: Some(proto::UsageLimit {
3010            variant: Some(match model_requests_limit(plan, feature_flags) {
3011                cloud_llm_client::UsageLimit::Limited(limit) => {
3012                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3013                        limit: limit as u32,
3014                    })
3015                }
3016                cloud_llm_client::UsageLimit::Unlimited => {
3017                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3018                }
3019            }),
3020        }),
3021        edit_predictions_usage_amount: 0,
3022        edit_predictions_usage_limit: Some(proto::UsageLimit {
3023            variant: Some(match plan.edit_predictions_limit() {
3024                cloud_llm_client::UsageLimit::Limited(limit) => {
3025                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3026                        limit: limit as u32,
3027                    })
3028                }
3029                cloud_llm_client::UsageLimit::Unlimited => {
3030                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3031                }
3032            }),
3033        }),
3034    }
3035}
3036
3037async fn update_user_plan(session: &Session) -> Result<()> {
3038    let db = session.db().await;
3039
3040    let update_user_plan = make_update_user_plan_message(
3041        session.principal.user(),
3042        session.is_staff(),
3043        &db.0,
3044        session.app_state.llm_db.clone(),
3045    )
3046    .await?;
3047
3048    session
3049        .peer
3050        .send(session.connection_id, update_user_plan)
3051        .trace_err();
3052
3053    Ok(())
3054}
3055
3056async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3057    subscribe_user_to_channels(session.user_id(), &session).await?;
3058    Ok(())
3059}
3060
3061async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3062    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3063    let mut pool = session.connection_pool().await;
3064    for membership in &channels_for_user.channel_memberships {
3065        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3066    }
3067    session.peer.send(
3068        session.connection_id,
3069        build_update_user_channels(&channels_for_user),
3070    )?;
3071    session.peer.send(
3072        session.connection_id,
3073        build_channels_update(channels_for_user),
3074    )?;
3075    Ok(())
3076}
3077
3078/// Creates a new channel.
3079async fn create_channel(
3080    request: proto::CreateChannel,
3081    response: Response<proto::CreateChannel>,
3082    session: Session,
3083) -> Result<()> {
3084    let db = session.db().await;
3085
3086    let parent_id = request.parent_id.map(ChannelId::from_proto);
3087    let (channel, membership) = db
3088        .create_channel(&request.name, parent_id, session.user_id())
3089        .await?;
3090
3091    let root_id = channel.root_id();
3092    let channel = Channel::from_model(channel);
3093
3094    response.send(proto::CreateChannelResponse {
3095        channel: Some(channel.to_proto()),
3096        parent_id: request.parent_id,
3097    })?;
3098
3099    let mut connection_pool = session.connection_pool().await;
3100    if let Some(membership) = membership {
3101        connection_pool.subscribe_to_channel(
3102            membership.user_id,
3103            membership.channel_id,
3104            membership.role,
3105        );
3106        let update = proto::UpdateUserChannels {
3107            channel_memberships: vec![proto::ChannelMembership {
3108                channel_id: membership.channel_id.to_proto(),
3109                role: membership.role.into(),
3110            }],
3111            ..Default::default()
3112        };
3113        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3114            session.peer.send(connection_id, update.clone())?;
3115        }
3116    }
3117
3118    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3119        if !role.can_see_channel(channel.visibility) {
3120            continue;
3121        }
3122
3123        let update = proto::UpdateChannels {
3124            channels: vec![channel.to_proto()],
3125            ..Default::default()
3126        };
3127        session.peer.send(connection_id, update.clone())?;
3128    }
3129
3130    Ok(())
3131}
3132
3133/// Delete a channel
3134async fn delete_channel(
3135    request: proto::DeleteChannel,
3136    response: Response<proto::DeleteChannel>,
3137    session: Session,
3138) -> Result<()> {
3139    let db = session.db().await;
3140
3141    let channel_id = request.channel_id;
3142    let (root_channel, removed_channels) = db
3143        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3144        .await?;
3145    response.send(proto::Ack {})?;
3146
3147    // Notify members of removed channels
3148    let mut update = proto::UpdateChannels::default();
3149    update
3150        .delete_channels
3151        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3152
3153    let connection_pool = session.connection_pool().await;
3154    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3155        session.peer.send(connection_id, update.clone())?;
3156    }
3157
3158    Ok(())
3159}
3160
3161/// Invite someone to join a channel.
3162async fn invite_channel_member(
3163    request: proto::InviteChannelMember,
3164    response: Response<proto::InviteChannelMember>,
3165    session: Session,
3166) -> Result<()> {
3167    let db = session.db().await;
3168    let channel_id = ChannelId::from_proto(request.channel_id);
3169    let invitee_id = UserId::from_proto(request.user_id);
3170    let InviteMemberResult {
3171        channel,
3172        notifications,
3173    } = db
3174        .invite_channel_member(
3175            channel_id,
3176            invitee_id,
3177            session.user_id(),
3178            request.role().into(),
3179        )
3180        .await?;
3181
3182    let update = proto::UpdateChannels {
3183        channel_invitations: vec![channel.to_proto()],
3184        ..Default::default()
3185    };
3186
3187    let connection_pool = session.connection_pool().await;
3188    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3189        session.peer.send(connection_id, update.clone())?;
3190    }
3191
3192    send_notifications(&connection_pool, &session.peer, notifications);
3193
3194    response.send(proto::Ack {})?;
3195    Ok(())
3196}
3197
3198/// remove someone from a channel
3199async fn remove_channel_member(
3200    request: proto::RemoveChannelMember,
3201    response: Response<proto::RemoveChannelMember>,
3202    session: Session,
3203) -> Result<()> {
3204    let db = session.db().await;
3205    let channel_id = ChannelId::from_proto(request.channel_id);
3206    let member_id = UserId::from_proto(request.user_id);
3207
3208    let RemoveChannelMemberResult {
3209        membership_update,
3210        notification_id,
3211    } = db
3212        .remove_channel_member(channel_id, member_id, session.user_id())
3213        .await?;
3214
3215    let mut connection_pool = session.connection_pool().await;
3216    notify_membership_updated(
3217        &mut connection_pool,
3218        membership_update,
3219        member_id,
3220        &session.peer,
3221    );
3222    for connection_id in connection_pool.user_connection_ids(member_id) {
3223        if let Some(notification_id) = notification_id {
3224            session
3225                .peer
3226                .send(
3227                    connection_id,
3228                    proto::DeleteNotification {
3229                        notification_id: notification_id.to_proto(),
3230                    },
3231                )
3232                .trace_err();
3233        }
3234    }
3235
3236    response.send(proto::Ack {})?;
3237    Ok(())
3238}
3239
3240/// Toggle the channel between public and private.
3241/// Care is taken to maintain the invariant that public channels only descend from public channels,
3242/// (though members-only channels can appear at any point in the hierarchy).
3243async fn set_channel_visibility(
3244    request: proto::SetChannelVisibility,
3245    response: Response<proto::SetChannelVisibility>,
3246    session: Session,
3247) -> Result<()> {
3248    let db = session.db().await;
3249    let channel_id = ChannelId::from_proto(request.channel_id);
3250    let visibility = request.visibility().into();
3251
3252    let channel_model = db
3253        .set_channel_visibility(channel_id, visibility, session.user_id())
3254        .await?;
3255    let root_id = channel_model.root_id();
3256    let channel = Channel::from_model(channel_model);
3257
3258    let mut connection_pool = session.connection_pool().await;
3259    for (user_id, role) in connection_pool
3260        .channel_user_ids(root_id)
3261        .collect::<Vec<_>>()
3262        .into_iter()
3263    {
3264        let update = if role.can_see_channel(channel.visibility) {
3265            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3266            proto::UpdateChannels {
3267                channels: vec![channel.to_proto()],
3268                ..Default::default()
3269            }
3270        } else {
3271            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3272            proto::UpdateChannels {
3273                delete_channels: vec![channel.id.to_proto()],
3274                ..Default::default()
3275            }
3276        };
3277
3278        for connection_id in connection_pool.user_connection_ids(user_id) {
3279            session.peer.send(connection_id, update.clone())?;
3280        }
3281    }
3282
3283    response.send(proto::Ack {})?;
3284    Ok(())
3285}
3286
3287/// Alter the role for a user in the channel.
3288async fn set_channel_member_role(
3289    request: proto::SetChannelMemberRole,
3290    response: Response<proto::SetChannelMemberRole>,
3291    session: Session,
3292) -> Result<()> {
3293    let db = session.db().await;
3294    let channel_id = ChannelId::from_proto(request.channel_id);
3295    let member_id = UserId::from_proto(request.user_id);
3296    let result = db
3297        .set_channel_member_role(
3298            channel_id,
3299            session.user_id(),
3300            member_id,
3301            request.role().into(),
3302        )
3303        .await?;
3304
3305    match result {
3306        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3307            let mut connection_pool = session.connection_pool().await;
3308            notify_membership_updated(
3309                &mut connection_pool,
3310                membership_update,
3311                member_id,
3312                &session.peer,
3313            )
3314        }
3315        db::SetMemberRoleResult::InviteUpdated(channel) => {
3316            let update = proto::UpdateChannels {
3317                channel_invitations: vec![channel.to_proto()],
3318                ..Default::default()
3319            };
3320
3321            for connection_id in session
3322                .connection_pool()
3323                .await
3324                .user_connection_ids(member_id)
3325            {
3326                session.peer.send(connection_id, update.clone())?;
3327            }
3328        }
3329    }
3330
3331    response.send(proto::Ack {})?;
3332    Ok(())
3333}
3334
3335/// Change the name of a channel
3336async fn rename_channel(
3337    request: proto::RenameChannel,
3338    response: Response<proto::RenameChannel>,
3339    session: Session,
3340) -> Result<()> {
3341    let db = session.db().await;
3342    let channel_id = ChannelId::from_proto(request.channel_id);
3343    let channel_model = db
3344        .rename_channel(channel_id, session.user_id(), &request.name)
3345        .await?;
3346    let root_id = channel_model.root_id();
3347    let channel = Channel::from_model(channel_model);
3348
3349    response.send(proto::RenameChannelResponse {
3350        channel: Some(channel.to_proto()),
3351    })?;
3352
3353    let connection_pool = session.connection_pool().await;
3354    let update = proto::UpdateChannels {
3355        channels: vec![channel.to_proto()],
3356        ..Default::default()
3357    };
3358    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3359        if role.can_see_channel(channel.visibility) {
3360            session.peer.send(connection_id, update.clone())?;
3361        }
3362    }
3363
3364    Ok(())
3365}
3366
3367/// Move a channel to a new parent.
3368async fn move_channel(
3369    request: proto::MoveChannel,
3370    response: Response<proto::MoveChannel>,
3371    session: Session,
3372) -> Result<()> {
3373    let channel_id = ChannelId::from_proto(request.channel_id);
3374    let to = ChannelId::from_proto(request.to);
3375
3376    let (root_id, channels) = session
3377        .db()
3378        .await
3379        .move_channel(channel_id, to, session.user_id())
3380        .await?;
3381
3382    let connection_pool = session.connection_pool().await;
3383    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3384        let channels = channels
3385            .iter()
3386            .filter_map(|channel| {
3387                if role.can_see_channel(channel.visibility) {
3388                    Some(channel.to_proto())
3389                } else {
3390                    None
3391                }
3392            })
3393            .collect::<Vec<_>>();
3394        if channels.is_empty() {
3395            continue;
3396        }
3397
3398        let update = proto::UpdateChannels {
3399            channels,
3400            ..Default::default()
3401        };
3402
3403        session.peer.send(connection_id, update.clone())?;
3404    }
3405
3406    response.send(Ack {})?;
3407    Ok(())
3408}
3409
3410async fn reorder_channel(
3411    request: proto::ReorderChannel,
3412    response: Response<proto::ReorderChannel>,
3413    session: Session,
3414) -> Result<()> {
3415    let channel_id = ChannelId::from_proto(request.channel_id);
3416    let direction = request.direction();
3417
3418    let updated_channels = session
3419        .db()
3420        .await
3421        .reorder_channel(channel_id, direction, session.user_id())
3422        .await?;
3423
3424    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3425        let connection_pool = session.connection_pool().await;
3426        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3427            let channels = updated_channels
3428                .iter()
3429                .filter_map(|channel| {
3430                    if role.can_see_channel(channel.visibility) {
3431                        Some(channel.to_proto())
3432                    } else {
3433                        None
3434                    }
3435                })
3436                .collect::<Vec<_>>();
3437
3438            if channels.is_empty() {
3439                continue;
3440            }
3441
3442            let update = proto::UpdateChannels {
3443                channels,
3444                ..Default::default()
3445            };
3446
3447            session.peer.send(connection_id, update.clone())?;
3448        }
3449    }
3450
3451    response.send(Ack {})?;
3452    Ok(())
3453}
3454
3455/// Get the list of channel members
3456async fn get_channel_members(
3457    request: proto::GetChannelMembers,
3458    response: Response<proto::GetChannelMembers>,
3459    session: Session,
3460) -> Result<()> {
3461    let db = session.db().await;
3462    let channel_id = ChannelId::from_proto(request.channel_id);
3463    let limit = if request.limit == 0 {
3464        u16::MAX as u64
3465    } else {
3466        request.limit
3467    };
3468    let (members, users) = db
3469        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3470        .await?;
3471    response.send(proto::GetChannelMembersResponse { members, users })?;
3472    Ok(())
3473}
3474
3475/// Accept or decline a channel invitation.
3476async fn respond_to_channel_invite(
3477    request: proto::RespondToChannelInvite,
3478    response: Response<proto::RespondToChannelInvite>,
3479    session: Session,
3480) -> Result<()> {
3481    let db = session.db().await;
3482    let channel_id = ChannelId::from_proto(request.channel_id);
3483    let RespondToChannelInvite {
3484        membership_update,
3485        notifications,
3486    } = db
3487        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3488        .await?;
3489
3490    let mut connection_pool = session.connection_pool().await;
3491    if let Some(membership_update) = membership_update {
3492        notify_membership_updated(
3493            &mut connection_pool,
3494            membership_update,
3495            session.user_id(),
3496            &session.peer,
3497        );
3498    } else {
3499        let update = proto::UpdateChannels {
3500            remove_channel_invitations: vec![channel_id.to_proto()],
3501            ..Default::default()
3502        };
3503
3504        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3505            session.peer.send(connection_id, update.clone())?;
3506        }
3507    };
3508
3509    send_notifications(&connection_pool, &session.peer, notifications);
3510
3511    response.send(proto::Ack {})?;
3512
3513    Ok(())
3514}
3515
3516/// Join the channels' room
3517async fn join_channel(
3518    request: proto::JoinChannel,
3519    response: Response<proto::JoinChannel>,
3520    session: Session,
3521) -> Result<()> {
3522    let channel_id = ChannelId::from_proto(request.channel_id);
3523    join_channel_internal(channel_id, Box::new(response), session).await
3524}
3525
/// A response handle that can carry a `JoinRoomResponse`.
///
/// Implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can reply to either
/// kind of request.
trait JoinChannelInternalResponse {
    /// Send `result` to the requester, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Both impls call `Response::<…>::send` with an explicit type path so the call
// resolves to `Response`'s own `send` (presumably an inherent method) rather
// than recursing into this trait's `send`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3539
/// Join the room belonging to `channel_id` on behalf of the session's user and
/// reply through `response`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that
///    room first.
/// 2. Join the channel's room in the database. This yields the joined room, an
///    optional membership update, and the user's role in the channel.
/// 3. If a LiveKit client is configured, mint a token:
///    - guests get a guest token and `can_publish: false`;
///    - every other role gets a room token and `can_publish: true`.
///    `trace_err` turns a token error into `None` (presumably logging it), so
///    the reply then carries no LiveKit connection info instead of failing.
/// 4. Send the `JoinRoomResponse`, then send any membership notifications and
///    broadcast the room update.
/// 5. Broadcast the channel update and refresh the user's contacts.
///
/// # Errors
/// Returns an error if any of the following fail:
/// - a database call;
/// - leaving the stale room;
/// - sending the response;
/// - updating contacts.
///
/// It also returns an error if the joined room has no associated channel. That
/// check happens after the response has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database handle before leaving the stale room, then
            // re-acquire it. NOTE(review): presumably `leave_room_for_session`
            // takes `session.db()` itself and would otherwise contend with us.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests receive a subscribe-only token; all other roles may publish.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // Leaving this block drops both `db` and `connection_pool`.
        joined_room
    };

    // A fresh connection-pool guard is taken here; the one above was dropped at
    // the end of the block.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3633
3634/// Start editing the channel notes
3635async fn join_channel_buffer(
3636    request: proto::JoinChannelBuffer,
3637    response: Response<proto::JoinChannelBuffer>,
3638    session: Session,
3639) -> Result<()> {
3640    let db = session.db().await;
3641    let channel_id = ChannelId::from_proto(request.channel_id);
3642
3643    let open_response = db
3644        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3645        .await?;
3646
3647    let collaborators = open_response.collaborators.clone();
3648    response.send(open_response)?;
3649
3650    let update = UpdateChannelBufferCollaborators {
3651        channel_id: channel_id.to_proto(),
3652        collaborators: collaborators.clone(),
3653    };
3654    channel_buffer_updated(
3655        session.connection_id,
3656        collaborators
3657            .iter()
3658            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3659        &update,
3660        &session.peer,
3661    );
3662
3663    Ok(())
3664}
3665
3666/// Edit the channel notes
3667async fn update_channel_buffer(
3668    request: proto::UpdateChannelBuffer,
3669    session: Session,
3670) -> Result<()> {
3671    let db = session.db().await;
3672    let channel_id = ChannelId::from_proto(request.channel_id);
3673
3674    let (collaborators, epoch, version) = db
3675        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3676        .await?;
3677
3678    channel_buffer_updated(
3679        session.connection_id,
3680        collaborators.clone(),
3681        &proto::UpdateChannelBuffer {
3682            channel_id: channel_id.to_proto(),
3683            operations: request.operations,
3684        },
3685        &session.peer,
3686    );
3687
3688    let pool = &*session.connection_pool().await;
3689
3690    let non_collaborators =
3691        pool.channel_connection_ids(channel_id)
3692            .filter_map(|(connection_id, _)| {
3693                if collaborators.contains(&connection_id) {
3694                    None
3695                } else {
3696                    Some(connection_id)
3697                }
3698            });
3699
3700    broadcast(None, non_collaborators, |peer_id| {
3701        session.peer.send(
3702            peer_id,
3703            proto::UpdateChannels {
3704                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3705                    channel_id: channel_id.to_proto(),
3706                    epoch: epoch as u64,
3707                    version: version.clone(),
3708                }],
3709                ..Default::default()
3710            },
3711        )
3712    });
3713
3714    Ok(())
3715}
3716
3717/// Rejoin the channel notes after a connection blip
3718async fn rejoin_channel_buffers(
3719    request: proto::RejoinChannelBuffers,
3720    response: Response<proto::RejoinChannelBuffers>,
3721    session: Session,
3722) -> Result<()> {
3723    let db = session.db().await;
3724    let buffers = db
3725        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3726        .await?;
3727
3728    for rejoined_buffer in &buffers {
3729        let collaborators_to_notify = rejoined_buffer
3730            .buffer
3731            .collaborators
3732            .iter()
3733            .filter_map(|c| Some(c.peer_id?.into()));
3734        channel_buffer_updated(
3735            session.connection_id,
3736            collaborators_to_notify,
3737            &proto::UpdateChannelBufferCollaborators {
3738                channel_id: rejoined_buffer.buffer.channel_id,
3739                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3740            },
3741            &session.peer,
3742        );
3743    }
3744
3745    response.send(proto::RejoinChannelBuffersResponse {
3746        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3747    })?;
3748
3749    Ok(())
3750}
3751
3752/// Stop editing the channel notes
3753async fn leave_channel_buffer(
3754    request: proto::LeaveChannelBuffer,
3755    response: Response<proto::LeaveChannelBuffer>,
3756    session: Session,
3757) -> Result<()> {
3758    let db = session.db().await;
3759    let channel_id = ChannelId::from_proto(request.channel_id);
3760
3761    let left_buffer = db
3762        .leave_channel_buffer(channel_id, session.connection_id)
3763        .await?;
3764
3765    response.send(Ack {})?;
3766
3767    channel_buffer_updated(
3768        session.connection_id,
3769        left_buffer.connections,
3770        &proto::UpdateChannelBufferCollaborators {
3771            channel_id: channel_id.to_proto(),
3772            collaborators: left_buffer.collaborators,
3773        },
3774        &session.peer,
3775    );
3776
3777    Ok(())
3778}
3779
3780fn channel_buffer_updated<T: EnvelopedMessage>(
3781    sender_id: ConnectionId,
3782    collaborators: impl IntoIterator<Item = ConnectionId>,
3783    message: &T,
3784    peer: &Peer,
3785) {
3786    broadcast(Some(sender_id), collaborators, |peer_id| {
3787        peer.send(peer_id, message.clone())
3788    });
3789}
3790
3791fn send_notifications(
3792    connection_pool: &ConnectionPool,
3793    peer: &Peer,
3794    notifications: db::NotificationBatch,
3795) {
3796    for (user_id, notification) in notifications {
3797        for connection_id in connection_pool.user_connection_ids(user_id) {
3798            if let Err(error) = peer.send(
3799                connection_id,
3800                proto::AddNotification {
3801                    notification: Some(notification.clone()),
3802                },
3803            ) {
3804                tracing::error!(
3805                    "failed to send notification to {:?} {}",
3806                    connection_id,
3807                    error
3808                );
3809            }
3810        }
3811    }
3812}
3813
3814/// Send a message to the channel
3815async fn send_channel_message(
3816    request: proto::SendChannelMessage,
3817    response: Response<proto::SendChannelMessage>,
3818    session: Session,
3819) -> Result<()> {
3820    // Validate the message body.
3821    let body = request.body.trim().to_string();
3822    if body.len() > MAX_MESSAGE_LEN {
3823        return Err(anyhow!("message is too long"))?;
3824    }
3825    if body.is_empty() {
3826        return Err(anyhow!("message can't be blank"))?;
3827    }
3828
3829    // TODO: adjust mentions if body is trimmed
3830
3831    let timestamp = OffsetDateTime::now_utc();
3832    let nonce = request.nonce.context("nonce can't be blank")?;
3833
3834    let channel_id = ChannelId::from_proto(request.channel_id);
3835    let CreatedChannelMessage {
3836        message_id,
3837        participant_connection_ids,
3838        notifications,
3839    } = session
3840        .db()
3841        .await
3842        .create_channel_message(
3843            channel_id,
3844            session.user_id(),
3845            &body,
3846            &request.mentions,
3847            timestamp,
3848            nonce.clone().into(),
3849            request.reply_to_message_id.map(MessageId::from_proto),
3850        )
3851        .await?;
3852
3853    let message = proto::ChannelMessage {
3854        sender_id: session.user_id().to_proto(),
3855        id: message_id.to_proto(),
3856        body,
3857        mentions: request.mentions,
3858        timestamp: timestamp.unix_timestamp() as u64,
3859        nonce: Some(nonce),
3860        reply_to_message_id: request.reply_to_message_id,
3861        edited_at: None,
3862    };
3863    broadcast(
3864        Some(session.connection_id),
3865        participant_connection_ids.clone(),
3866        |connection| {
3867            session.peer.send(
3868                connection,
3869                proto::ChannelMessageSent {
3870                    channel_id: channel_id.to_proto(),
3871                    message: Some(message.clone()),
3872                },
3873            )
3874        },
3875    );
3876    response.send(proto::SendChannelMessageResponse {
3877        message: Some(message),
3878    })?;
3879
3880    let pool = &*session.connection_pool().await;
3881    let non_participants =
3882        pool.channel_connection_ids(channel_id)
3883            .filter_map(|(connection_id, _)| {
3884                if participant_connection_ids.contains(&connection_id) {
3885                    None
3886                } else {
3887                    Some(connection_id)
3888                }
3889            });
3890    broadcast(None, non_participants, |peer_id| {
3891        session.peer.send(
3892            peer_id,
3893            proto::UpdateChannels {
3894                latest_channel_message_ids: vec![proto::ChannelMessageId {
3895                    channel_id: channel_id.to_proto(),
3896                    message_id: message_id.to_proto(),
3897                }],
3898                ..Default::default()
3899            },
3900        )
3901    });
3902    send_notifications(pool, &session.peer, notifications);
3903
3904    Ok(())
3905}
3906
3907/// Delete a channel message
3908async fn remove_channel_message(
3909    request: proto::RemoveChannelMessage,
3910    response: Response<proto::RemoveChannelMessage>,
3911    session: Session,
3912) -> Result<()> {
3913    let channel_id = ChannelId::from_proto(request.channel_id);
3914    let message_id = MessageId::from_proto(request.message_id);
3915    let (connection_ids, existing_notification_ids) = session
3916        .db()
3917        .await
3918        .remove_channel_message(channel_id, message_id, session.user_id())
3919        .await?;
3920
3921    broadcast(
3922        Some(session.connection_id),
3923        connection_ids,
3924        move |connection| {
3925            session.peer.send(connection, request.clone())?;
3926
3927            for notification_id in &existing_notification_ids {
3928                session.peer.send(
3929                    connection,
3930                    proto::DeleteNotification {
3931                        notification_id: (*notification_id).to_proto(),
3932                    },
3933                )?;
3934            }
3935
3936            Ok(())
3937        },
3938    );
3939    response.send(proto::Ack {})?;
3940    Ok(())
3941}
3942
3943async fn update_channel_message(
3944    request: proto::UpdateChannelMessage,
3945    response: Response<proto::UpdateChannelMessage>,
3946    session: Session,
3947) -> Result<()> {
3948    let channel_id = ChannelId::from_proto(request.channel_id);
3949    let message_id = MessageId::from_proto(request.message_id);
3950    let updated_at = OffsetDateTime::now_utc();
3951    let UpdatedChannelMessage {
3952        message_id,
3953        participant_connection_ids,
3954        notifications,
3955        reply_to_message_id,
3956        timestamp,
3957        deleted_mention_notification_ids,
3958        updated_mention_notifications,
3959    } = session
3960        .db()
3961        .await
3962        .update_channel_message(
3963            channel_id,
3964            message_id,
3965            session.user_id(),
3966            request.body.as_str(),
3967            &request.mentions,
3968            updated_at,
3969        )
3970        .await?;
3971
3972    let nonce = request.nonce.clone().context("nonce can't be blank")?;
3973
3974    let message = proto::ChannelMessage {
3975        sender_id: session.user_id().to_proto(),
3976        id: message_id.to_proto(),
3977        body: request.body.clone(),
3978        mentions: request.mentions.clone(),
3979        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3980        nonce: Some(nonce),
3981        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3982        edited_at: Some(updated_at.unix_timestamp() as u64),
3983    };
3984
3985    response.send(proto::Ack {})?;
3986
3987    let pool = &*session.connection_pool().await;
3988    broadcast(
3989        Some(session.connection_id),
3990        participant_connection_ids,
3991        |connection| {
3992            session.peer.send(
3993                connection,
3994                proto::ChannelMessageUpdate {
3995                    channel_id: channel_id.to_proto(),
3996                    message: Some(message.clone()),
3997                },
3998            )?;
3999
4000            for notification_id in &deleted_mention_notification_ids {
4001                session.peer.send(
4002                    connection,
4003                    proto::DeleteNotification {
4004                        notification_id: (*notification_id).to_proto(),
4005                    },
4006                )?;
4007            }
4008
4009            for notification in &updated_mention_notifications {
4010                session.peer.send(
4011                    connection,
4012                    proto::UpdateNotification {
4013                        notification: Some(notification.clone()),
4014                    },
4015                )?;
4016            }
4017
4018            Ok(())
4019        },
4020    );
4021
4022    send_notifications(pool, &session.peer, notifications);
4023
4024    Ok(())
4025}
4026
4027/// Mark a channel message as read
4028async fn acknowledge_channel_message(
4029    request: proto::AckChannelMessage,
4030    session: Session,
4031) -> Result<()> {
4032    let channel_id = ChannelId::from_proto(request.channel_id);
4033    let message_id = MessageId::from_proto(request.message_id);
4034    let notifications = session
4035        .db()
4036        .await
4037        .observe_channel_message(channel_id, session.user_id(), message_id)
4038        .await?;
4039    send_notifications(
4040        &*session.connection_pool().await,
4041        &session.peer,
4042        notifications,
4043    );
4044    Ok(())
4045}
4046
4047/// Mark a buffer version as synced
4048async fn acknowledge_buffer_version(
4049    request: proto::AckBufferOperation,
4050    session: Session,
4051) -> Result<()> {
4052    let buffer_id = BufferId::from_proto(request.buffer_id);
4053    session
4054        .db()
4055        .await
4056        .observe_buffer_version(
4057            buffer_id,
4058            session.user_id(),
4059            request.epoch as i32,
4060            &request.version,
4061        )
4062        .await?;
4063    Ok(())
4064}
4065
4066/// Get a Supermaven API key for the user
4067async fn get_supermaven_api_key(
4068    _request: proto::GetSupermavenApiKey,
4069    response: Response<proto::GetSupermavenApiKey>,
4070    session: Session,
4071) -> Result<()> {
4072    let user_id: String = session.user_id().to_string();
4073    if !session.is_staff() {
4074        return Err(anyhow!("supermaven not enabled for this account"))?;
4075    }
4076
4077    let email = session.email().context("user must have an email")?;
4078
4079    let supermaven_admin_api = session
4080        .supermaven_client
4081        .as_ref()
4082        .context("supermaven not configured")?;
4083
4084    let result = supermaven_admin_api
4085        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4086        .await?;
4087
4088    response.send(proto::GetSupermavenApiKeyResponse {
4089        api_key: result.api_key,
4090    })?;
4091
4092    Ok(())
4093}
4094
4095/// Start receiving chat updates for a channel
4096async fn join_channel_chat(
4097    request: proto::JoinChannelChat,
4098    response: Response<proto::JoinChannelChat>,
4099    session: Session,
4100) -> Result<()> {
4101    let channel_id = ChannelId::from_proto(request.channel_id);
4102
4103    let db = session.db().await;
4104    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4105        .await?;
4106    let messages = db
4107        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4108        .await?;
4109    response.send(proto::JoinChannelChatResponse {
4110        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4111        messages,
4112    })?;
4113    Ok(())
4114}
4115
4116/// Stop receiving chat updates for a channel
4117async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
4118    let channel_id = ChannelId::from_proto(request.channel_id);
4119    session
4120        .db()
4121        .await
4122        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4123        .await?;
4124    Ok(())
4125}
4126
4127/// Retrieve the chat history for a channel
4128async fn get_channel_messages(
4129    request: proto::GetChannelMessages,
4130    response: Response<proto::GetChannelMessages>,
4131    session: Session,
4132) -> Result<()> {
4133    let channel_id = ChannelId::from_proto(request.channel_id);
4134    let messages = session
4135        .db()
4136        .await
4137        .get_channel_messages(
4138            channel_id,
4139            session.user_id(),
4140            MESSAGE_COUNT_PER_PAGE,
4141            Some(MessageId::from_proto(request.before_message_id)),
4142        )
4143        .await?;
4144    response.send(proto::GetChannelMessagesResponse {
4145        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4146        messages,
4147    })?;
4148    Ok(())
4149}
4150
4151/// Retrieve specific chat messages
4152async fn get_channel_messages_by_id(
4153    request: proto::GetChannelMessagesById,
4154    response: Response<proto::GetChannelMessagesById>,
4155    session: Session,
4156) -> Result<()> {
4157    let message_ids = request
4158        .message_ids
4159        .iter()
4160        .map(|id| MessageId::from_proto(*id))
4161        .collect::<Vec<_>>();
4162    let messages = session
4163        .db()
4164        .await
4165        .get_channel_messages_by_id(session.user_id(), &message_ids)
4166        .await?;
4167    response.send(proto::GetChannelMessagesResponse {
4168        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4169        messages,
4170    })?;
4171    Ok(())
4172}
4173
4174/// Retrieve the current users notifications
4175async fn get_notifications(
4176    request: proto::GetNotifications,
4177    response: Response<proto::GetNotifications>,
4178    session: Session,
4179) -> Result<()> {
4180    let notifications = session
4181        .db()
4182        .await
4183        .get_notifications(
4184            session.user_id(),
4185            NOTIFICATION_COUNT_PER_PAGE,
4186            request.before_id.map(db::NotificationId::from_proto),
4187        )
4188        .await?;
4189    response.send(proto::GetNotificationsResponse {
4190        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4191        notifications,
4192    })?;
4193    Ok(())
4194}
4195
4196/// Mark notifications as read
4197async fn mark_notification_as_read(
4198    request: proto::MarkNotificationRead,
4199    response: Response<proto::MarkNotificationRead>,
4200    session: Session,
4201) -> Result<()> {
4202    let database = &session.db().await;
4203    let notifications = database
4204        .mark_notification_as_read_by_id(
4205            session.user_id(),
4206            NotificationId::from_proto(request.notification_id),
4207        )
4208        .await?;
4209    send_notifications(
4210        &*session.connection_pool().await,
4211        &session.peer,
4212        notifications,
4213    );
4214    response.send(proto::Ack {})?;
4215    Ok(())
4216}
4217
4218/// Get the current users information
4219async fn get_private_user_info(
4220    _request: proto::GetPrivateUserInfo,
4221    response: Response<proto::GetPrivateUserInfo>,
4222    session: Session,
4223) -> Result<()> {
4224    let db = session.db().await;
4225
4226    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4227    let user = db
4228        .get_user_by_id(session.user_id())
4229        .await?
4230        .context("user not found")?;
4231    let flags = db.get_user_flags(session.user_id()).await?;
4232
4233    response.send(proto::GetPrivateUserInfoResponse {
4234        metrics_id,
4235        staff: user.admin,
4236        flags,
4237        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4238    })?;
4239    Ok(())
4240}
4241
4242/// Accept the terms of service (tos) on behalf of the current user
4243async fn accept_terms_of_service(
4244    _request: proto::AcceptTermsOfService,
4245    response: Response<proto::AcceptTermsOfService>,
4246    session: Session,
4247) -> Result<()> {
4248    let db = session.db().await;
4249
4250    let accepted_tos_at = Utc::now();
4251    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4252        .await?;
4253
4254    response.send(proto::AcceptTermsOfServiceResponse {
4255        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4256    })?;
4257
4258    // When the user accepts the terms of service, we want to refresh their LLM
4259    // token to grant access.
4260    session
4261        .peer
4262        .send(session.connection_id, proto::RefreshLlmToken {})?;
4263
4264    Ok(())
4265}
4266
4267async fn get_llm_api_token(
4268    _request: proto::GetLlmToken,
4269    response: Response<proto::GetLlmToken>,
4270    session: Session,
4271) -> Result<()> {
4272    let db = session.db().await;
4273
4274    let flags = db.get_user_flags(session.user_id()).await?;
4275
4276    let user_id = session.user_id();
4277    let user = db
4278        .get_user_by_id(user_id)
4279        .await?
4280        .with_context(|| format!("user {user_id} not found"))?;
4281
4282    if user.accepted_tos_at.is_none() {
4283        Err(anyhow!("terms of service not accepted"))?
4284    }
4285
4286    let stripe_client = session
4287        .app_state
4288        .stripe_client
4289        .as_ref()
4290        .context("failed to retrieve Stripe client")?;
4291
4292    let stripe_billing = session
4293        .app_state
4294        .stripe_billing
4295        .as_ref()
4296        .context("failed to retrieve Stripe billing object")?;
4297
4298    let billing_customer = if let Some(billing_customer) =
4299        db.get_billing_customer_by_user_id(user.id).await?
4300    {
4301        billing_customer
4302    } else {
4303        let customer_id = stripe_billing
4304            .find_or_create_customer_by_email(user.email_address.as_deref())
4305            .await?;
4306
4307        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
4308            .await?
4309            .context("billing customer not found")?
4310    };
4311
4312    let billing_subscription =
4313        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
4314            billing_subscription
4315        } else {
4316            let stripe_customer_id =
4317                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
4318
4319            let stripe_subscription = stripe_billing
4320                .subscribe_to_zed_free(stripe_customer_id)
4321                .await?;
4322
4323            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
4324                billing_customer_id: billing_customer.id,
4325                kind: Some(SubscriptionKind::ZedFree),
4326                stripe_subscription_id: stripe_subscription.id.to_string(),
4327                stripe_subscription_status: stripe_subscription.status.into(),
4328                stripe_cancellation_reason: None,
4329                stripe_current_period_start: Some(stripe_subscription.current_period_start),
4330                stripe_current_period_end: Some(stripe_subscription.current_period_end),
4331            })
4332            .await?
4333        };
4334
4335    let billing_preferences = db.get_billing_preferences(user.id).await?;
4336
4337    let token = LlmTokenClaims::create(
4338        &user,
4339        session.is_staff(),
4340        billing_customer,
4341        billing_preferences,
4342        &flags,
4343        billing_subscription,
4344        session.system_id.clone(),
4345        &session.app_state.config,
4346    )?;
4347    response.send(proto::GetLlmTokenResponse { token })?;
4348    Ok(())
4349}
4350
4351fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4352    let message = match message {
4353        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4354        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4355        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4356        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4357        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4358            code: frame.code.into(),
4359            reason: frame.reason.as_str().to_owned().into(),
4360        })),
4361        // We should never receive a frame while reading the message, according
4362        // to the `tungstenite` maintainers:
4363        //
4364        // > It cannot occur when you read messages from the WebSocket, but it
4365        // > can be used when you want to send the raw frames (e.g. you want to
4366        // > send the frames to the WebSocket without composing the full message first).
4367        // >
4368        // > — https://github.com/snapview/tungstenite-rs/issues/268
4369        TungsteniteMessage::Frame(_) => {
4370            bail!("received an unexpected frame while reading the message")
4371        }
4372    };
4373
4374    Ok(message)
4375}
4376
4377fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4378    match message {
4379        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4380        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4381        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4382        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4383        AxumMessage::Close(frame) => {
4384            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4385                code: frame.code.into(),
4386                reason: frame.reason.as_ref().into(),
4387            }))
4388        }
4389    }
4390}
4391
4392fn notify_membership_updated(
4393    connection_pool: &mut ConnectionPool,
4394    result: MembershipUpdated,
4395    user_id: UserId,
4396    peer: &Peer,
4397) {
4398    for membership in &result.new_channels.channel_memberships {
4399        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4400    }
4401    for channel_id in &result.removed_channels {
4402        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4403    }
4404
4405    let user_channels_update = proto::UpdateUserChannels {
4406        channel_memberships: result
4407            .new_channels
4408            .channel_memberships
4409            .iter()
4410            .map(|cm| proto::ChannelMembership {
4411                channel_id: cm.channel_id.to_proto(),
4412                role: cm.role.into(),
4413            })
4414            .collect(),
4415        ..Default::default()
4416    };
4417
4418    let mut update = build_channels_update(result.new_channels);
4419    update.delete_channels = result
4420        .removed_channels
4421        .into_iter()
4422        .map(|id| id.to_proto())
4423        .collect();
4424    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4425
4426    for connection_id in connection_pool.user_connection_ids(user_id) {
4427        peer.send(connection_id, user_channels_update.clone())
4428            .trace_err();
4429        peer.send(connection_id, update.clone()).trace_err();
4430    }
4431}
4432
4433fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4434    proto::UpdateUserChannels {
4435        channel_memberships: channels
4436            .channel_memberships
4437            .iter()
4438            .map(|m| proto::ChannelMembership {
4439                channel_id: m.channel_id.to_proto(),
4440                role: m.role.into(),
4441            })
4442            .collect(),
4443        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4444        observed_channel_message_id: channels.observed_channel_messages.clone(),
4445    }
4446}
4447
4448fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4449    let mut update = proto::UpdateChannels::default();
4450
4451    for channel in channels.channels {
4452        update.channels.push(channel.to_proto());
4453    }
4454
4455    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4456    update.latest_channel_message_ids = channels.latest_channel_messages;
4457
4458    for (channel_id, participants) in channels.channel_participants {
4459        update
4460            .channel_participants
4461            .push(proto::ChannelParticipants {
4462                channel_id: channel_id.to_proto(),
4463                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4464            });
4465    }
4466
4467    for channel in channels.invited_channels {
4468        update.channel_invitations.push(channel.to_proto());
4469    }
4470
4471    update
4472}
4473
4474fn build_initial_contacts_update(
4475    contacts: Vec<db::Contact>,
4476    pool: &ConnectionPool,
4477) -> proto::UpdateContacts {
4478    let mut update = proto::UpdateContacts::default();
4479
4480    for contact in contacts {
4481        match contact {
4482            db::Contact::Accepted { user_id, busy } => {
4483                update.contacts.push(contact_for_user(user_id, busy, pool));
4484            }
4485            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4486            db::Contact::Incoming { user_id } => {
4487                update
4488                    .incoming_requests
4489                    .push(proto::IncomingContactRequest {
4490                        requester_id: user_id.to_proto(),
4491                    })
4492            }
4493        }
4494    }
4495
4496    update
4497}
4498
4499fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4500    proto::Contact {
4501        user_id: user_id.to_proto(),
4502        online: pool.is_user_online(user_id),
4503        busy,
4504    }
4505}
4506
4507fn room_updated(room: &proto::Room, peer: &Peer) {
4508    broadcast(
4509        None,
4510        room.participants
4511            .iter()
4512            .filter_map(|participant| Some(participant.peer_id?.into())),
4513        |peer_id| {
4514            peer.send(
4515                peer_id,
4516                proto::RoomUpdated {
4517                    room: Some(room.clone()),
4518                },
4519            )
4520        },
4521    );
4522}
4523
4524fn channel_updated(
4525    channel: &db::channel::Model,
4526    room: &proto::Room,
4527    peer: &Peer,
4528    pool: &ConnectionPool,
4529) {
4530    let participants = room
4531        .participants
4532        .iter()
4533        .map(|p| p.user_id)
4534        .collect::<Vec<_>>();
4535
4536    broadcast(
4537        None,
4538        pool.channel_connection_ids(channel.root_id())
4539            .filter_map(|(channel_id, role)| {
4540                role.can_see_channel(channel.visibility)
4541                    .then_some(channel_id)
4542            }),
4543        |peer_id| {
4544            peer.send(
4545                peer_id,
4546                proto::UpdateChannels {
4547                    channel_participants: vec![proto::ChannelParticipants {
4548                        channel_id: channel.id.to_proto(),
4549                        participant_user_ids: participants.clone(),
4550                    }],
4551                    ..Default::default()
4552                },
4553            )
4554        },
4555    );
4556}
4557
4558async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4559    let db = session.db().await;
4560
4561    let contacts = db.get_contacts(user_id).await?;
4562    let busy = db.is_user_busy(user_id).await?;
4563
4564    let pool = session.connection_pool().await;
4565    let updated_contact = contact_for_user(user_id, busy, &pool);
4566    for contact in contacts {
4567        if let db::Contact::Accepted {
4568            user_id: contact_user_id,
4569            ..
4570        } = contact
4571        {
4572            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4573                session
4574                    .peer
4575                    .send(
4576                        contact_conn_id,
4577                        proto::UpdateContacts {
4578                            contacts: vec![updated_contact.clone()],
4579                            remove_contacts: Default::default(),
4580                            incoming_requests: Default::default(),
4581                            remove_incoming_requests: Default::default(),
4582                            outgoing_requests: Default::default(),
4583                            remove_outgoing_requests: Default::default(),
4584                        },
4585                    )
4586                    .trace_err();
4587            }
4588        }
4589    }
4590    Ok(())
4591}
4592
/// Removes `connection_id` from whatever room it is in and notifies everyone
/// affected.
///
/// If `leave_room` reports that no room was left, this returns `Ok(())`
/// without doing anything else. Otherwise it:
/// - calls `project_left` for every project left along with the room;
/// - broadcasts the updated room via `room_updated` and, if the room belongs
///   to a channel, the channel's new participant list via `channel_updated`;
/// - sends `CallCanceled` to every connection of each user whose call was
///   canceled;
/// - refreshes contact status for the session's user and those users;
/// - if a LiveKit client is configured, removes the session's user from the
///   LiveKit room and deletes the room when `left_room.deleted` is set.
///
/// Database and contact-update errors are returned; peer-send and LiveKit
/// failures are only logged via `trace_err`.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared here and assigned inside the `if let` below so the values
    // survive past that block, while `left_room` (and the `session.db()`
    // temporary in the scrutinee) are dropped when it ends.
    // NOTE(review): presumably this keeps any guard carried by the
    // `leave_room` result scoped to that block — confirm before restructuring.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `left_room.room` is moved out on the next lines, so the
        // `room` broadcast below holds the default `livekit_room` value.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so `pool` is dropped before `update_user_contacts`, which calls
    // `session.connection_pool()` itself (presumably a lock — confirm).
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4666
4667async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4668    let left_channel_buffers = session
4669        .db()
4670        .await
4671        .leave_channel_buffers(session.connection_id)
4672        .await?;
4673
4674    for left_buffer in left_channel_buffers {
4675        channel_buffer_updated(
4676            session.connection_id,
4677            left_buffer.connections,
4678            &proto::UpdateChannelBufferCollaborators {
4679                channel_id: left_buffer.channel_id.to_proto(),
4680                collaborators: left_buffer.collaborators,
4681            },
4682            &session.peer,
4683        );
4684    }
4685
4686    Ok(())
4687}
4688
4689fn project_left(project: &db::LeftProject, session: &Session) {
4690    for connection_id in &project.connection_ids {
4691        if project.should_unshare {
4692            session
4693                .peer
4694                .send(
4695                    *connection_id,
4696                    proto::UnshareProject {
4697                        project_id: project.id.to_proto(),
4698                    },
4699                )
4700                .trace_err();
4701        } else {
4702            session
4703                .peer
4704                .send(
4705                    *connection_id,
4706                    proto::RemoveProjectCollaborator {
4707                        project_id: project.id.to_proto(),
4708                        peer_id: Some(session.connection_id.into()),
4709                    },
4710                )
4711                .trace_err();
4712        }
4713    }
4714}
4715
4716pub trait ResultExt {
4717    type Ok;
4718
4719    fn trace_err(self) -> Option<Self::Ok>;
4720}
4721
4722impl<T, E> ResultExt for Result<T, E>
4723where
4724    E: std::fmt::Debug,
4725{
4726    type Ok = T;
4727
4728    #[track_caller]
4729    fn trace_err(self) -> Option<T> {
4730        match self {
4731            Ok(value) => Some(value),
4732            Err(error) => {
4733                tracing::error!("{:?}", error);
4734                None
4735            }
4736        }
4737    }
4738}