rpc.rs

   1mod connection_pool;
   2
   3use crate::api::billing::find_or_create_billing_customer;
   4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   5use crate::db::billing_subscription::SubscriptionKind;
   6use crate::llm::db::LlmDatabase;
   7use crate::llm::{
   8    AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
   9    MIN_ACCOUNT_AGE_FOR_LLM_USE,
  10};
  11use crate::stripe_client::StripeCustomerId;
  12use crate::{
  13    AppState, Error, Result, auth,
  14    db::{
  15        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
  16        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  17        NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
  18        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  19    },
  20    executor::Executor,
  21};
  22use anyhow::{Context as _, anyhow, bail};
  23use async_tungstenite::tungstenite::{
  24    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  25};
  26use axum::headers::UserAgent;
  27use axum::{
  28    Extension, Router, TypedHeader,
  29    body::Body,
  30    extract::{
  31        ConnectInfo, WebSocketUpgrade,
  32        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  33    },
  34    headers::{Header, HeaderName},
  35    http::StatusCode,
  36    middleware,
  37    response::IntoResponse,
  38    routing::get,
  39};
  40use chrono::Utc;
  41use collections::{HashMap, HashSet};
  42pub use connection_pool::{ConnectionPool, ZedVersion};
  43use core::fmt::{self, Debug, Formatter};
  44use reqwest_client::ReqwestClient;
  45use rpc::proto::{MultiLspQuery, split_repository_update};
  46use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  47
  48use futures::{
  49    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  50    stream::FuturesUnordered,
  51};
  52use prometheus::{IntGauge, register_int_gauge};
  53use rpc::{
  54    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  55    proto::{
  56        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  57        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  58    },
  59};
  60use semantic_version::SemanticVersion;
  61use serde::{Serialize, Serializer};
  62use std::{
  63    any::TypeId,
  64    future::Future,
  65    marker::PhantomData,
  66    mem,
  67    net::SocketAddr,
  68    ops::{Deref, DerefMut},
  69    rc::Rc,
  70    sync::{
  71        Arc, OnceLock,
  72        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  73    },
  74    time::{Duration, Instant},
  75};
  76use time::OffsetDateTime;
  77use tokio::sync::{Semaphore, watch};
  78use tower::ServiceBuilder;
  79use tracing::{
  80    Instrument,
  81    field::{self},
  82    info_span, instrument,
  83};
  84
  85pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  86
  87// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  88pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  89
  90const MESSAGE_COUNT_PER_PAGE: usize = 100;
  91const MAX_MESSAGE_LEN: usize = 1024;
  92const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  93const MAX_CONCURRENT_CONNECTIONS: usize = 512;
  94
  95static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
  96
  97type MessageHandler =
  98    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  99
 100pub struct ConnectionGuard;
 101
 102impl ConnectionGuard {
 103    pub fn try_acquire() -> Result<Self, ()> {
 104        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
 105        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 106            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 107            tracing::error!(
 108                "too many concurrent connections: {}",
 109                current_connections + 1
 110            );
 111            return Err(());
 112        }
 113        Ok(ConnectionGuard)
 114    }
 115}
 116
 117impl Drop for ConnectionGuard {
 118    fn drop(&mut self) {
 119        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 120    }
 121}
 122
 123struct Response<R> {
 124    peer: Arc<Peer>,
 125    receipt: Receipt<R>,
 126    responded: Arc<AtomicBool>,
 127}
 128
 129impl<R: RequestMessage> Response<R> {
 130    fn send(self, payload: R::Response) -> Result<()> {
 131        self.responded.store(true, SeqCst);
 132        self.peer.respond(self.receipt, payload)?;
 133        Ok(())
 134    }
 135}
 136
 137#[derive(Clone, Debug)]
 138pub enum Principal {
 139    User(User),
 140    Impersonated { user: User, admin: User },
 141}
 142
 143impl Principal {
 144    fn user(&self) -> &User {
 145        match self {
 146            Principal::User(user) => user,
 147            Principal::Impersonated { user, .. } => user,
 148        }
 149    }
 150
 151    fn update_span(&self, span: &tracing::Span) {
 152        match &self {
 153            Principal::User(user) => {
 154                span.record("user_id", user.id.0);
 155                span.record("login", &user.github_login);
 156            }
 157            Principal::Impersonated { user, admin } => {
 158                span.record("user_id", user.id.0);
 159                span.record("login", &user.github_login);
 160                span.record("impersonator", &admin.github_login);
 161            }
 162        }
 163    }
 164}
 165
 166#[derive(Clone)]
 167struct Session {
 168    principal: Principal,
 169    connection_id: ConnectionId,
 170    db: Arc<tokio::sync::Mutex<DbHandle>>,
 171    peer: Arc<Peer>,
 172    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 173    app_state: Arc<AppState>,
 174    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 175    /// The GeoIP country code for the user.
 176    #[allow(unused)]
 177    geoip_country_code: Option<String>,
 178    system_id: Option<String>,
 179    _executor: Executor,
 180}
 181
 182impl Session {
 183    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 184        #[cfg(test)]
 185        tokio::task::yield_now().await;
 186        let guard = self.db.lock().await;
 187        #[cfg(test)]
 188        tokio::task::yield_now().await;
 189        guard
 190    }
 191
 192    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 193        #[cfg(test)]
 194        tokio::task::yield_now().await;
 195        let guard = self.connection_pool.lock();
 196        ConnectionPoolGuard {
 197            guard,
 198            _not_send: PhantomData,
 199        }
 200    }
 201
 202    fn is_staff(&self) -> bool {
 203        match &self.principal {
 204            Principal::User(user) => user.admin,
 205            Principal::Impersonated { .. } => true,
 206        }
 207    }
 208
 209    fn user_id(&self) -> UserId {
 210        match &self.principal {
 211            Principal::User(user) => user.id,
 212            Principal::Impersonated { user, .. } => user.id,
 213        }
 214    }
 215
 216    pub fn email(&self) -> Option<String> {
 217        match &self.principal {
 218            Principal::User(user) => user.email_address.clone(),
 219            Principal::Impersonated { user, .. } => user.email_address.clone(),
 220        }
 221    }
 222}
 223
 224impl Debug for Session {
 225    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 226        let mut result = f.debug_struct("Session");
 227        match &self.principal {
 228            Principal::User(user) => {
 229                result.field("user", &user.github_login);
 230            }
 231            Principal::Impersonated { user, admin } => {
 232                result.field("user", &user.github_login);
 233                result.field("impersonator", &admin.github_login);
 234            }
 235        }
 236        result.field("connection_id", &self.connection_id).finish()
 237    }
 238}
 239
 240struct DbHandle(Arc<Database>);
 241
 242impl Deref for DbHandle {
 243    type Target = Database;
 244
 245    fn deref(&self) -> &Self::Target {
 246        self.0.as_ref()
 247    }
 248}
 249
 250pub struct Server {
 251    id: parking_lot::Mutex<ServerId>,
 252    peer: Arc<Peer>,
 253    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 254    app_state: Arc<AppState>,
 255    handlers: HashMap<TypeId, MessageHandler>,
 256    teardown: watch::Sender<bool>,
 257}
 258
 259pub(crate) struct ConnectionPoolGuard<'a> {
 260    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 261    _not_send: PhantomData<Rc<()>>,
 262}
 263
 264#[derive(Serialize)]
 265pub struct ServerSnapshot<'a> {
 266    peer: &'a Peer,
 267    #[serde(serialize_with = "serialize_deref")]
 268    connection_pool: ConnectionPoolGuard<'a>,
 269}
 270
 271pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 272where
 273    S: Serializer,
 274    T: Deref<Target = U>,
 275    U: Serialize,
 276{
 277    Serialize::serialize(value.deref(), serializer)
 278}
 279
 280impl Server {
 281    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 282        let mut server = Self {
 283            id: parking_lot::Mutex::new(id),
 284            peer: Peer::new(id.0 as u32),
 285            app_state: app_state.clone(),
 286            connection_pool: Default::default(),
 287            handlers: Default::default(),
 288            teardown: watch::channel(false).0,
 289        };
 290
 291        server
 292            .add_request_handler(ping)
 293            .add_request_handler(create_room)
 294            .add_request_handler(join_room)
 295            .add_request_handler(rejoin_room)
 296            .add_request_handler(leave_room)
 297            .add_request_handler(set_room_participant_role)
 298            .add_request_handler(call)
 299            .add_request_handler(cancel_call)
 300            .add_message_handler(decline_call)
 301            .add_request_handler(update_participant_location)
 302            .add_request_handler(share_project)
 303            .add_message_handler(unshare_project)
 304            .add_request_handler(join_project)
 305            .add_message_handler(leave_project)
 306            .add_request_handler(update_project)
 307            .add_request_handler(update_worktree)
 308            .add_request_handler(update_repository)
 309            .add_request_handler(remove_repository)
 310            .add_message_handler(start_language_server)
 311            .add_message_handler(update_language_server)
 312            .add_message_handler(update_diagnostic_summary)
 313            .add_message_handler(update_worktree_settings)
 314            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 315            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 316            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 317            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 318            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 319            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 320            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 321            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 322            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 323            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 324            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 325            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 326            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 327            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 328            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 329            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 330            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 331            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 332            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 333            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 334            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 335            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 336            .add_request_handler(
 337                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 338            )
 339            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 340            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 341            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 342            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 343            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
 344            .add_request_handler(
 345                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 346            )
 347            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 348            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 349            .add_request_handler(
 350                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 351            )
 352            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 353            .add_request_handler(
 354                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 355            )
 356            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 357            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 358            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 359            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 360            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 361            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 362            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 363            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 364            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 365            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 366            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 367            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 368            .add_request_handler(
 369                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 370            )
 371            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 372            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 373            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 374            .add_request_handler(multi_lsp_query)
 375            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 376            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 377            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 378            .add_message_handler(create_buffer_for_peer)
 379            .add_request_handler(update_buffer)
 380            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 381            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 382            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 383            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 384            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 385            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 386            .add_message_handler(
 387                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 388            )
 389            .add_request_handler(get_users)
 390            .add_request_handler(fuzzy_search_users)
 391            .add_request_handler(request_contact)
 392            .add_request_handler(remove_contact)
 393            .add_request_handler(respond_to_contact_request)
 394            .add_message_handler(subscribe_to_channels)
 395            .add_request_handler(create_channel)
 396            .add_request_handler(delete_channel)
 397            .add_request_handler(invite_channel_member)
 398            .add_request_handler(remove_channel_member)
 399            .add_request_handler(set_channel_member_role)
 400            .add_request_handler(set_channel_visibility)
 401            .add_request_handler(rename_channel)
 402            .add_request_handler(join_channel_buffer)
 403            .add_request_handler(leave_channel_buffer)
 404            .add_message_handler(update_channel_buffer)
 405            .add_request_handler(rejoin_channel_buffers)
 406            .add_request_handler(get_channel_members)
 407            .add_request_handler(respond_to_channel_invite)
 408            .add_request_handler(join_channel)
 409            .add_request_handler(join_channel_chat)
 410            .add_message_handler(leave_channel_chat)
 411            .add_request_handler(send_channel_message)
 412            .add_request_handler(remove_channel_message)
 413            .add_request_handler(update_channel_message)
 414            .add_request_handler(get_channel_messages)
 415            .add_request_handler(get_channel_messages_by_id)
 416            .add_request_handler(get_notifications)
 417            .add_request_handler(mark_notification_as_read)
 418            .add_request_handler(move_channel)
 419            .add_request_handler(reorder_channel)
 420            .add_request_handler(follow)
 421            .add_message_handler(unfollow)
 422            .add_message_handler(update_followers)
 423            .add_request_handler(get_private_user_info)
 424            .add_request_handler(get_llm_api_token)
 425            .add_request_handler(accept_terms_of_service)
 426            .add_message_handler(acknowledge_channel_message)
 427            .add_message_handler(acknowledge_buffer_version)
 428            .add_request_handler(get_supermaven_api_key)
 429            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 430            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 431            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 432            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 433            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 434            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 435            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 436            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 437            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 438            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 439            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 440            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 441            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 442            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 443            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 444            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 445            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 446            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 447            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 448            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 449            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 450            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 451            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 452            .add_message_handler(update_context);
 453
 454        Arc::new(server)
 455    }
 456
 457    pub async fn start(&self) -> Result<()> {
 458        let server_id = *self.id.lock();
 459        let app_state = self.app_state.clone();
 460        let peer = self.peer.clone();
 461        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 462        let pool = self.connection_pool.clone();
 463        let livekit_client = self.app_state.livekit_client.clone();
 464
 465        let span = info_span!("start server");
 466        self.app_state.executor.spawn_detached(
 467            async move {
 468                tracing::info!("waiting for cleanup timeout");
 469                timeout.await;
 470                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 471
 472                app_state
 473                    .db
 474                    .delete_stale_channel_chat_participants(
 475                        &app_state.config.zed_environment,
 476                        server_id,
 477                    )
 478                    .await
 479                    .trace_err();
 480
 481                if let Some((room_ids, channel_ids)) = app_state
 482                    .db
 483                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 484                    .await
 485                    .trace_err()
 486                {
 487                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 488                    tracing::info!(
 489                        stale_channel_buffer_count = channel_ids.len(),
 490                        "retrieved stale channel buffers"
 491                    );
 492
 493                    for channel_id in channel_ids {
 494                        if let Some(refreshed_channel_buffer) = app_state
 495                            .db
 496                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 497                            .await
 498                            .trace_err()
 499                        {
 500                            for connection_id in refreshed_channel_buffer.connection_ids {
 501                                peer.send(
 502                                    connection_id,
 503                                    proto::UpdateChannelBufferCollaborators {
 504                                        channel_id: channel_id.to_proto(),
 505                                        collaborators: refreshed_channel_buffer
 506                                            .collaborators
 507                                            .clone(),
 508                                    },
 509                                )
 510                                .trace_err();
 511                            }
 512                        }
 513                    }
 514
 515                    for room_id in room_ids {
 516                        let mut contacts_to_update = HashSet::default();
 517                        let mut canceled_calls_to_user_ids = Vec::new();
 518                        let mut livekit_room = String::new();
 519                        let mut delete_livekit_room = false;
 520
 521                        if let Some(mut refreshed_room) = app_state
 522                            .db
 523                            .clear_stale_room_participants(room_id, server_id)
 524                            .await
 525                            .trace_err()
 526                        {
 527                            tracing::info!(
 528                                room_id = room_id.0,
 529                                new_participant_count = refreshed_room.room.participants.len(),
 530                                "refreshed room"
 531                            );
 532                            room_updated(&refreshed_room.room, &peer);
 533                            if let Some(channel) = refreshed_room.channel.as_ref() {
 534                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 535                            }
 536                            contacts_to_update
 537                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 538                            contacts_to_update
 539                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 540                            canceled_calls_to_user_ids =
 541                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 542                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
 543                            delete_livekit_room = refreshed_room.room.participants.is_empty();
 544                        }
 545
 546                        {
 547                            let pool = pool.lock();
 548                            for canceled_user_id in canceled_calls_to_user_ids {
 549                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 550                                    peer.send(
 551                                        connection_id,
 552                                        proto::CallCanceled {
 553                                            room_id: room_id.to_proto(),
 554                                        },
 555                                    )
 556                                    .trace_err();
 557                                }
 558                            }
 559                        }
 560
 561                        for user_id in contacts_to_update {
 562                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 563                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 564                            if let Some((busy, contacts)) = busy.zip(contacts) {
 565                                let pool = pool.lock();
 566                                let updated_contact = contact_for_user(user_id, busy, &pool);
 567                                for contact in contacts {
 568                                    if let db::Contact::Accepted {
 569                                        user_id: contact_user_id,
 570                                        ..
 571                                    } = contact
 572                                    {
 573                                        for contact_conn_id in
 574                                            pool.user_connection_ids(contact_user_id)
 575                                        {
 576                                            peer.send(
 577                                                contact_conn_id,
 578                                                proto::UpdateContacts {
 579                                                    contacts: vec![updated_contact.clone()],
 580                                                    remove_contacts: Default::default(),
 581                                                    incoming_requests: Default::default(),
 582                                                    remove_incoming_requests: Default::default(),
 583                                                    outgoing_requests: Default::default(),
 584                                                    remove_outgoing_requests: Default::default(),
 585                                                },
 586                                            )
 587                                            .trace_err();
 588                                        }
 589                                    }
 590                                }
 591                            }
 592                        }
 593
 594                        if let Some(live_kit) = livekit_client.as_ref() {
 595                            if delete_livekit_room {
 596                                live_kit.delete_room(livekit_room).await.trace_err();
 597                            }
 598                        }
 599                    }
 600                }
 601
 602                app_state
 603                    .db
 604                    .delete_stale_channel_chat_participants(
 605                        &app_state.config.zed_environment,
 606                        server_id,
 607                    )
 608                    .await
 609                    .trace_err();
 610
 611                app_state
 612                    .db
 613                    .clear_old_worktree_entries(server_id)
 614                    .await
 615                    .trace_err();
 616
 617                app_state
 618                    .db
 619                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 620                    .await
 621                    .trace_err();
 622            }
 623            .instrument(span),
 624        );
 625        Ok(())
 626    }
 627
 628    pub fn teardown(&self) {
 629        self.peer.teardown();
 630        self.connection_pool.lock().reset();
 631        let _ = self.teardown.send(true);
 632    }
 633
 634    #[cfg(test)]
 635    pub fn reset(&self, id: ServerId) {
 636        self.teardown();
 637        *self.id.lock() = id;
 638        self.peer.reset(id.0 as u32);
 639        let _ = self.teardown.send(false);
 640    }
 641
 642    #[cfg(test)]
 643    pub fn id(&self) -> ServerId {
 644        *self.id.lock()
 645    }
 646
    /// Registers `handler` as the callback for messages of type `M`.
    ///
    /// The stored, type-erased closure:
    /// - downcasts the incoming envelope back to `TypedEnvelope<M>`;
    /// - invokes `handler`;
    /// - once the handler's future completes, logs the outcome together with
    ///   three timings.
    ///
    /// The timings are: the total time since the message was received, the
    /// time since the handler was invoked, and the difference between the two
    /// (the time the message waited before its handler was invoked).
    ///
    /// # Panics
    ///
    /// Panics if a handler for `M` has already been registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session| {
                // Handlers are keyed by `TypeId::of::<M>()`. The dispatcher looks them up
                // by `message.payload_type_id()`, which presumably equals that id, so the
                // downcast is expected to succeed.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                // Captured before the handler is invoked so queue time can be separated out.
                let start_time = Instant::now();
                let future = (handler)(*envelope, session);
                async move {
                    let result = future.await;
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    // Time between receipt and handler invocation.
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;

                    match result {
                        Err(error) => {
                            tracing::error!(
                                ?error,
                                total_duration_ms,
                                processing_duration_ms,
                                queue_duration_ms,
                                "error handling message"
                            )
                        }
                        Ok(()) => tracing::info!(
                            total_duration_ms,
                            processing_duration_ms,
                            queue_duration_ms,
                            "finished handling message"
                        ),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
 693
 694    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 695    where
 696        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 697        Fut: 'static + Send + Future<Output = Result<()>>,
 698        M: EnvelopedMessage,
 699    {
 700        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 701        self
 702    }
 703
    /// Registers a handler for request message `M`.
    ///
    /// The handler must answer through the `Response<M>` it is given.
    ///
    /// When the handler's future completes:
    /// - `Ok(())` without a response having been sent is turned into a
    ///   "handler did not send a response" error.
    /// - `Err(error)` is first reported back to the requester through
    ///   `peer.respond_with_error`. `Error::Internal` values are converted with
    ///   their own `to_proto()`; any other error is sent as
    ///   `ErrorCode::Internal` carrying its display text. The error is then
    ///   returned, so `add_handler` logs it.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        // Shared so each invocation's future can own a handle to the handler.
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // Shared with `Response` and presumably set when it sends a reply;
                // checked below to detect handlers that never respond.
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
 742
 743    pub fn handle_connection(
 744        self: &Arc<Self>,
 745        connection: Connection,
 746        address: String,
 747        principal: Principal,
 748        zed_version: ZedVersion,
 749        release_channel: Option<String>,
 750        user_agent: Option<String>,
 751        geoip_country_code: Option<String>,
 752        system_id: Option<String>,
 753        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 754        executor: Executor,
 755        connection_guard: Option<ConnectionGuard>,
 756    ) -> impl Future<Output = ()> + use<> {
 757        let this = self.clone();
 758        let span = info_span!("handle connection", %address,
 759            connection_id=field::Empty,
 760            user_id=field::Empty,
 761            login=field::Empty,
 762            impersonator=field::Empty,
 763            user_agent=field::Empty,
 764            geoip_country_code=field::Empty
 765        );
 766        principal.update_span(&span);
 767        if let Some(user_agent) = user_agent {
 768            span.record("user_agent", user_agent);
 769        }
 770        if let Some(release_channel) = release_channel {
 771            span.record("release_channel", release_channel);
 772        }
 773
 774        if let Some(country_code) = geoip_country_code.as_ref() {
 775            span.record("geoip_country_code", country_code);
 776        }
 777
 778        let mut teardown = self.teardown.subscribe();
 779        async move {
 780            if *teardown.borrow() {
 781                tracing::error!("server is tearing down");
 782                return;
 783            }
 784
 785            let (connection_id, handle_io, mut incoming_rx) =
 786                this.peer.add_connection(connection, {
 787                    let executor = executor.clone();
 788                    move |duration| executor.sleep(duration)
 789                });
 790            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 791
 792            tracing::info!("connection opened");
 793
 794            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
 795            let http_client = match ReqwestClient::user_agent(&user_agent) {
 796                Ok(http_client) => Arc::new(http_client),
 797                Err(error) => {
 798                    tracing::error!(?error, "failed to create HTTP client");
 799                    return;
 800                }
 801            };
 802
 803            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
 804                |supermaven_admin_api_key| {
 805                    Arc::new(SupermavenAdminApi::new(
 806                        supermaven_admin_api_key.to_string(),
 807                        http_client.clone(),
 808                    ))
 809                },
 810            );
 811
 812            let session = Session {
 813                principal: principal.clone(),
 814                connection_id,
 815                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 816                peer: this.peer.clone(),
 817                connection_pool: this.connection_pool.clone(),
 818                app_state: this.app_state.clone(),
 819                geoip_country_code,
 820                system_id,
 821                _executor: executor.clone(),
 822                supermaven_client,
 823            };
 824
 825            if let Err(error) = this
 826                .send_initial_client_update(
 827                    connection_id,
 828                    zed_version,
 829                    send_connection_id,
 830                    &session,
 831                )
 832                .await
 833            {
 834                tracing::error!(?error, "failed to send initial client update");
 835                return;
 836            }
 837            drop(connection_guard);
 838
 839            let handle_io = handle_io.fuse();
 840            futures::pin_mut!(handle_io);
 841
 842            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 843            // This prevents deadlocks when e.g., client A performs a request to client B and
 844            // client B performs a request to client A. If both clients stop processing further
 845            // messages until their respective request completes, they won't have a chance to
 846            // respond to the other client's request and cause a deadlock.
 847            //
 848            // This arrangement ensures we will attempt to process earlier messages first, but fall
 849            // back to processing messages arrived later in the spirit of making progress.
 850            const MAX_CONCURRENT_HANDLERS: usize = 256;
 851            let mut foreground_message_handlers = FuturesUnordered::new();
 852            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
 853            let get_concurrent_handlers = {
 854                let concurrent_handlers = concurrent_handlers.clone();
 855                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
 856            };
 857            loop {
 858                let next_message = async {
 859                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 860                    let message = incoming_rx.next().await;
 861                    // Cache the concurrent_handlers here, so that we know what the
 862                    // queue looks like as each handler starts
 863                    (permit, message, get_concurrent_handlers())
 864                }
 865                .fuse();
 866                futures::pin_mut!(next_message);
 867                futures::select_biased! {
 868                    _ = teardown.changed().fuse() => return,
 869                    result = handle_io => {
 870                        if let Err(error) = result {
 871                            tracing::error!(?error, "error handling I/O");
 872                        }
 873                        break;
 874                    }
 875                    _ = foreground_message_handlers.next() => {}
 876                    next_message = next_message => {
 877                        let (permit, message, concurrent_handlers) = next_message;
 878                        if let Some(message) = message {
 879                            let type_name = message.payload_type_name();
 880                            // note: we copy all the fields from the parent span so we can query them in the logs.
 881                            // (https://github.com/tokio-rs/tracing/issues/2670).
 882                            let span = tracing::info_span!("receive message",
 883                                %connection_id,
 884                                %address,
 885                                type_name,
 886                                concurrent_handlers,
 887                                user_id=field::Empty,
 888                                login=field::Empty,
 889                                impersonator=field::Empty,
 890                                multi_lsp_query_request=field::Empty,
 891                            );
 892                            principal.update_span(&span);
 893                            let span_enter = span.enter();
 894                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 895                                let is_background = message.is_background();
 896                                let handle_message = (handler)(message, session.clone());
 897                                drop(span_enter);
 898
 899                                let handle_message = async move {
 900                                    handle_message.await;
 901                                    drop(permit);
 902                                }.instrument(span);
 903                                if is_background {
 904                                    executor.spawn_detached(handle_message);
 905                                } else {
 906                                    foreground_message_handlers.push(handle_message);
 907                                }
 908                            } else {
 909                                tracing::error!("no message handler");
 910                            }
 911                        } else {
 912                            tracing::info!("connection closed");
 913                            break;
 914                        }
 915                    }
 916                }
 917            }
 918
 919            drop(foreground_message_handlers);
 920            let concurrent_handlers = get_concurrent_handlers();
 921            tracing::info!(concurrent_handlers, "signing out");
 922            if let Err(error) = connection_lost(session, teardown, executor).await {
 923                tracing::error!(?error, "error signing out");
 924            }
 925        }
 926        .instrument(span)
 927    }
 928
    /// Sends the messages a freshly connected client needs before normal
    /// message dispatch starts.
    ///
    /// Always sends `proto::Hello` with this connection's peer id, then hands
    /// `connection_id` to `send_connection_id` if one was provided. For user or
    /// impersonated principals it then, in order:
    /// - on the user's first-ever connection, sends `proto::ShowContacts` and
    ///   marks the user as connected once;
    /// - pushes the user's plan via `update_user_plan`;
    /// - adds the connection to the pool and sends the initial contacts update;
    /// - subscribes the user to channels when
    ///   `should_auto_subscribe_to_channels(zed_version)` allows it;
    /// - forwards any pending incoming call;
    /// - runs `update_user_contacts` for the user.
    ///
    /// Any send or database failure is returned immediately, and the remaining
    /// steps are skipped.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // The receiver may already be gone; that is not an error here.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                update_user_plan(session).await?;

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                // Scoped so the pool lock is released before the awaits below.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
 986
 987    pub async fn invite_code_redeemed(
 988        self: &Arc<Self>,
 989        inviter_id: UserId,
 990        invitee_id: UserId,
 991    ) -> Result<()> {
 992        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 993            if let Some(code) = &user.invite_code {
 994                let pool = self.connection_pool.lock();
 995                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 996                for connection_id in pool.user_connection_ids(inviter_id) {
 997                    self.peer.send(
 998                        connection_id,
 999                        proto::UpdateContacts {
1000                            contacts: vec![invitee_contact.clone()],
1001                            ..Default::default()
1002                        },
1003                    )?;
1004                    self.peer.send(
1005                        connection_id,
1006                        proto::UpdateInviteInfo {
1007                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1008                            count: user.invite_count as u32,
1009                        },
1010                    )?;
1011                }
1012            }
1013        }
1014        Ok(())
1015    }
1016
1017    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1018        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1019            if let Some(invite_code) = &user.invite_code {
1020                let pool = self.connection_pool.lock();
1021                for connection_id in pool.user_connection_ids(user_id) {
1022                    self.peer.send(
1023                        connection_id,
1024                        proto::UpdateInviteInfo {
1025                            url: format!(
1026                                "{}{}",
1027                                self.app_state.config.invite_link_prefix, invite_code
1028                            ),
1029                            count: user.invite_count as u32,
1030                        },
1031                    )?;
1032                }
1033            }
1034        }
1035        Ok(())
1036    }
1037
1038    pub async fn update_plan_for_user(
1039        self: &Arc<Self>,
1040        user_id: UserId,
1041        update_user_plan: proto::UpdateUserPlan,
1042    ) -> Result<()> {
1043        let pool = self.connection_pool.lock();
1044        for connection_id in pool.user_connection_ids(user_id) {
1045            self.peer
1046                .send(connection_id, update_user_plan.clone())
1047                .trace_err();
1048        }
1049
1050        Ok(())
1051    }
1052
1053    /// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan`
1054    /// message on the Collab server.
1055    ///
1056    /// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint.
1057    pub async fn update_plan_for_user_legacy(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1058        let user = self
1059            .app_state
1060            .db
1061            .get_user_by_id(user_id)
1062            .await?
1063            .context("user not found")?;
1064
1065        let update_user_plan = make_update_user_plan_message(
1066            &user,
1067            user.admin,
1068            &self.app_state.db,
1069            self.app_state.llm_db.clone(),
1070        )
1071        .await?;
1072
1073        self.update_plan_for_user(user_id, update_user_plan).await
1074    }
1075
1076    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
1077        let pool = self.connection_pool.lock();
1078        for connection_id in pool.user_connection_ids(user_id) {
1079            self.peer
1080                .send(connection_id, proto::RefreshLlmToken {})
1081                .trace_err();
1082        }
1083    }
1084
1085    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1086        ServerSnapshot {
1087            connection_pool: ConnectionPoolGuard {
1088                guard: self.connection_pool.lock(),
1089                _not_send: PhantomData,
1090            },
1091            peer: &self.peer,
1092        }
1093    }
1094}
1095
1096impl Deref for ConnectionPoolGuard<'_> {
1097    type Target = ConnectionPool;
1098
1099    fn deref(&self) -> &Self::Target {
1100        &self.guard
1101    }
1102}
1103
1104impl DerefMut for ConnectionPoolGuard<'_> {
1105    fn deref_mut(&mut self) -> &mut Self::Target {
1106        &mut self.guard
1107    }
1108}
1109
1110impl Drop for ConnectionPoolGuard<'_> {
1111    fn drop(&mut self) {
1112        #[cfg(test)]
1113        self.check_invariants();
1114    }
1115}
1116
1117fn broadcast<F>(
1118    sender_id: Option<ConnectionId>,
1119    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1120    mut f: F,
1121) where
1122    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1123{
1124    for receiver_id in receiver_ids {
1125        if Some(receiver_id) != sender_id {
1126            if let Err(error) = f(receiver_id) {
1127                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1128            }
1129        }
1130    }
1131}
1132
1133pub struct ProtocolVersion(u32);
1134
1135impl Header for ProtocolVersion {
1136    fn name() -> &'static HeaderName {
1137        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1138        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1139    }
1140
1141    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1142    where
1143        Self: Sized,
1144        I: Iterator<Item = &'i axum::http::HeaderValue>,
1145    {
1146        let version = values
1147            .next()
1148            .ok_or_else(axum::headers::Error::invalid)?
1149            .to_str()
1150            .map_err(|_| axum::headers::Error::invalid())?
1151            .parse()
1152            .map_err(|_| axum::headers::Error::invalid())?;
1153        Ok(Self(version))
1154    }
1155
1156    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1157        values.extend([self.0.to_string().parse().unwrap()]);
1158    }
1159}
1160
1161pub struct AppVersionHeader(SemanticVersion);
1162impl Header for AppVersionHeader {
1163    fn name() -> &'static HeaderName {
1164        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1165        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1166    }
1167
1168    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1169    where
1170        Self: Sized,
1171        I: Iterator<Item = &'i axum::http::HeaderValue>,
1172    {
1173        let version = values
1174            .next()
1175            .ok_or_else(axum::headers::Error::invalid)?
1176            .to_str()
1177            .map_err(|_| axum::headers::Error::invalid())?
1178            .parse()
1179            .map_err(|_| axum::headers::Error::invalid())?;
1180        Ok(Self(version))
1181    }
1182
1183    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1184        values.extend([self.0.to_string().parse().unwrap()]);
1185    }
1186}
1187
1188#[derive(Debug)]
1189pub struct ReleaseChannelHeader(String);
1190
1191impl Header for ReleaseChannelHeader {
1192    fn name() -> &'static HeaderName {
1193        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1194        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1195    }
1196
1197    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1198    where
1199        Self: Sized,
1200        I: Iterator<Item = &'i axum::http::HeaderValue>,
1201    {
1202        Ok(Self(
1203            values
1204                .next()
1205                .ok_or_else(axum::headers::Error::invalid)?
1206                .to_str()
1207                .map_err(|_| axum::headers::Error::invalid())?
1208                .to_owned(),
1209        ))
1210    }
1211
1212    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1213        values.extend([self.0.parse().unwrap()]);
1214    }
1215}
1216
1217pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1218    Router::new()
1219        .route("/rpc", get(handle_websocket_request))
1220        .layer(
1221            ServiceBuilder::new()
1222                .layer(Extension(server.app_state.clone()))
1223                .layer(middleware::from_fn(auth::validate_header)),
1224        )
1225        .route("/metrics", get(handle_metrics))
1226        .layer(Extension(server))
1227}
1228
1229pub async fn handle_websocket_request(
1230    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1231    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1232    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1233    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1234    Extension(server): Extension<Arc<Server>>,
1235    Extension(principal): Extension<Principal>,
1236    user_agent: Option<TypedHeader<UserAgent>>,
1237    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1238    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1239    ws: WebSocketUpgrade,
1240) -> axum::response::Response {
1241    if protocol_version != rpc::PROTOCOL_VERSION {
1242        return (
1243            StatusCode::UPGRADE_REQUIRED,
1244            "client must be upgraded".to_string(),
1245        )
1246            .into_response();
1247    }
1248
1249    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1250        return (
1251            StatusCode::UPGRADE_REQUIRED,
1252            "no version header found".to_string(),
1253        )
1254            .into_response();
1255    };
1256
1257    let release_channel = release_channel_header.map(|header| header.0.0);
1258
1259    if !version.can_collaborate() {
1260        return (
1261            StatusCode::UPGRADE_REQUIRED,
1262            "client must be upgraded".to_string(),
1263        )
1264            .into_response();
1265    }
1266
1267    let socket_address = socket_address.to_string();
1268
1269    // Acquire connection guard before WebSocket upgrade
1270    let connection_guard = match ConnectionGuard::try_acquire() {
1271        Ok(guard) => guard,
1272        Err(()) => {
1273            return (
1274                StatusCode::SERVICE_UNAVAILABLE,
1275                "Too many concurrent connections",
1276            )
1277                .into_response();
1278        }
1279    };
1280
1281    ws.on_upgrade(move |socket| {
1282        let socket = socket
1283            .map_ok(to_tungstenite_message)
1284            .err_into()
1285            .with(|message| async move { to_axum_message(message) });
1286        let connection = Connection::new(Box::pin(socket));
1287        async move {
1288            server
1289                .handle_connection(
1290                    connection,
1291                    socket_address,
1292                    principal,
1293                    version,
1294                    release_channel,
1295                    user_agent.map(|header| header.to_string()),
1296                    country_code_header.map(|header| header.to_string()),
1297                    system_id_header.map(|header| header.to_string()),
1298                    None,
1299                    Executor::Production,
1300                    Some(connection_guard),
1301                )
1302                .await;
1303        }
1304    })
1305}
1306
1307pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1308    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1309    let connections_metric = CONNECTIONS_METRIC
1310        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1311
1312    let connections = server
1313        .connection_pool
1314        .lock()
1315        .connections()
1316        .filter(|connection| !connection.admin)
1317        .count();
1318    connections_metric.set(connections as _);
1319
1320    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1321    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1322        register_int_gauge!(
1323            "shared_projects",
1324            "number of open projects with one or more guests"
1325        )
1326        .unwrap()
1327    });
1328
1329    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1330    shared_projects_metric.set(shared_projects as _);
1331
1332    let encoder = prometheus::TextEncoder::new();
1333    let metric_families = prometheus::gather();
1334    let encoded_metrics = encoder
1335        .encode_to_string(&metric_families)
1336        .map_err(|err| anyhow!("{err}"))?;
1337    Ok(encoded_metrics)
1338}
1339
/// Cleans up after a connection's message loop has ended.
///
/// First, and unconditionally:
/// - disconnects the peer;
/// - removes the connection from the pool (a failure here is returned);
/// - reports the lost connection to the database (errors are only logged).
///
/// It then waits for `RECONNECT_TIMEOUT`, presumably to give the client a
/// window to reconnect. If that timeout elapses before a teardown
/// notification arrives, it:
/// - removes the session from its room and from its channel buffers;
/// - declines any pending call, if the user has no other online connection;
/// - refreshes the user's contacts.
///
/// A teardown notification ends the wait and skips that second phase.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased: the timeout branch is polled before the teardown branch.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline calls when no other connection keeps the user online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1386
1387/// Acknowledges a ping from a client, used to keep the connection alive.
1388async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1389    response.send(proto::Ack {})?;
1390    Ok(())
1391}
1392
1393/// Creates a new room for calling (outside of channels)
1394async fn create_room(
1395    _request: proto::CreateRoom,
1396    response: Response<proto::CreateRoom>,
1397    session: Session,
1398) -> Result<()> {
1399    let livekit_room = nanoid::nanoid!(30);
1400
1401    let live_kit_connection_info = util::maybe!(async {
1402        let live_kit = session.app_state.livekit_client.as_ref();
1403        let live_kit = live_kit?;
1404        let user_id = session.user_id().to_string();
1405
1406        let token = live_kit
1407            .room_token(&livekit_room, &user_id.to_string())
1408            .trace_err()?;
1409
1410        Some(proto::LiveKitConnectionInfo {
1411            server_url: live_kit.url().into(),
1412            token,
1413            can_publish: true,
1414        })
1415    })
1416    .await;
1417
1418    let room = session
1419        .db()
1420        .await
1421        .create_room(session.user_id(), session.connection_id, &livekit_room)
1422        .await?;
1423
1424    response.send(proto::CreateRoomResponse {
1425        room: Some(room.clone()),
1426        live_kit_connection_info,
1427    })?;
1428
1429    update_user_contacts(session.user_id(), &session).await?;
1430    Ok(())
1431}
1432
1433/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1434async fn join_room(
1435    request: proto::JoinRoom,
1436    response: Response<proto::JoinRoom>,
1437    session: Session,
1438) -> Result<()> {
1439    let room_id = RoomId::from_proto(request.id);
1440
1441    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1442
1443    if let Some(channel_id) = channel_id {
1444        return join_channel_internal(channel_id, Box::new(response), session).await;
1445    }
1446
1447    let joined_room = {
1448        let room = session
1449            .db()
1450            .await
1451            .join_room(room_id, session.user_id(), session.connection_id)
1452            .await?;
1453        room_updated(&room.room, &session.peer);
1454        room.into_inner()
1455    };
1456
1457    for connection_id in session
1458        .connection_pool()
1459        .await
1460        .user_connection_ids(session.user_id())
1461    {
1462        session
1463            .peer
1464            .send(
1465                connection_id,
1466                proto::CallCanceled {
1467                    room_id: room_id.to_proto(),
1468                },
1469            )
1470            .trace_err();
1471    }
1472
1473    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1474    {
1475        live_kit
1476            .room_token(
1477                &joined_room.room.livekit_room,
1478                &session.user_id().to_string(),
1479            )
1480            .trace_err()
1481            .map(|token| proto::LiveKitConnectionInfo {
1482                server_url: live_kit.url().into(),
1483                token,
1484                can_publish: true,
1485            })
1486    } else {
1487        None
1488    };
1489
1490    response.send(proto::JoinRoomResponse {
1491        room: Some(joined_room.room),
1492        channel_id: None,
1493        live_kit_connection_info,
1494    })?;
1495
1496    update_user_contacts(session.user_id(), &session).await?;
1497    Ok(())
1498}
1499
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database call re-registers this connection in the room and returns the
/// room together with two project lists: `reshared_projects` (presumably the
/// projects this user hosts — TODO confirm) and `rejoined_projects` (projects
/// rejoined as a guest, handled by `notify_rejoined_projects`).
///
/// Order of effects:
/// 1. the `RejoinRoomResponse` is sent to the caller;
/// 2. the room is rebroadcast;
/// 3. collaborators of each reshared project learn this user's new peer id and
///    receive the project's current worktrees via a forwarded `UpdateProject`;
/// 4. rejoined projects are handled by `notify_rejoined_projects`;
/// 5. after the database guard is consumed, the channel (if any) is updated and
///    the user's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Filled in from inside the scope below so they outlive the database guard.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that this user's peer id moved from the old
            // connection to the current one; send failures are only logged.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktree list to its collaborators, skipping
            // this connection itself.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Consumes the guard; only the room and channel are kept past this scope.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1591
1592fn notify_rejoined_projects(
1593    rejoined_projects: &mut Vec<RejoinedProject>,
1594    session: &Session,
1595) -> Result<()> {
1596    for project in rejoined_projects.iter() {
1597        for collaborator in &project.collaborators {
1598            session
1599                .peer
1600                .send(
1601                    collaborator.connection_id,
1602                    proto::UpdateProjectCollaborator {
1603                        project_id: project.id.to_proto(),
1604                        old_peer_id: Some(project.old_connection_id.into()),
1605                        new_peer_id: Some(session.connection_id.into()),
1606                    },
1607                )
1608                .trace_err();
1609        }
1610    }
1611
1612    for project in rejoined_projects {
1613        for worktree in mem::take(&mut project.worktrees) {
1614            // Stream this worktree's entries.
1615            let message = proto::UpdateWorktree {
1616                project_id: project.id.to_proto(),
1617                worktree_id: worktree.id,
1618                abs_path: worktree.abs_path.clone(),
1619                root_name: worktree.root_name,
1620                updated_entries: worktree.updated_entries,
1621                removed_entries: worktree.removed_entries,
1622                scan_id: worktree.scan_id,
1623                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1624                updated_repositories: worktree.updated_repositories,
1625                removed_repositories: worktree.removed_repositories,
1626            };
1627            for update in proto::split_worktree_update(message) {
1628                session.peer.send(session.connection_id, update)?;
1629            }
1630
1631            // Stream this worktree's diagnostics.
1632            for summary in worktree.diagnostic_summaries {
1633                session.peer.send(
1634                    session.connection_id,
1635                    proto::UpdateDiagnosticSummary {
1636                        project_id: project.id.to_proto(),
1637                        worktree_id: worktree.id,
1638                        summary: Some(summary),
1639                    },
1640                )?;
1641            }
1642
1643            for settings_file in worktree.settings_files {
1644                session.peer.send(
1645                    session.connection_id,
1646                    proto::UpdateWorktreeSettings {
1647                        project_id: project.id.to_proto(),
1648                        worktree_id: worktree.id,
1649                        path: settings_file.path,
1650                        content: Some(settings_file.content),
1651                        kind: Some(settings_file.kind.to_proto().into()),
1652                    },
1653                )?;
1654            }
1655        }
1656
1657        for repository in mem::take(&mut project.updated_repositories) {
1658            for update in split_repository_update(repository) {
1659                session.peer.send(session.connection_id, update)?;
1660            }
1661        }
1662
1663        for id in mem::take(&mut project.removed_repositories) {
1664            session.peer.send(
1665                session.connection_id,
1666                proto::RemoveRepository {
1667                    project_id: project.id.to_proto(),
1668                    id,
1669                },
1670            )?;
1671        }
1672    }
1673
1674    Ok(())
1675}
1676
1677/// leave room disconnects from the room.
1678async fn leave_room(
1679    _: proto::LeaveRoom,
1680    response: Response<proto::LeaveRoom>,
1681    session: Session,
1682) -> Result<()> {
1683    leave_room_for_session(&session, session.connection_id).await?;
1684    response.send(proto::Ack {})?;
1685    Ok(())
1686}
1687
1688/// Updates the permissions of someone else in the room.
1689async fn set_room_participant_role(
1690    request: proto::SetRoomParticipantRole,
1691    response: Response<proto::SetRoomParticipantRole>,
1692    session: Session,
1693) -> Result<()> {
1694    let user_id = UserId::from_proto(request.user_id);
1695    let role = ChannelRole::from(request.role());
1696
1697    let (livekit_room, can_publish) = {
1698        let room = session
1699            .db()
1700            .await
1701            .set_room_participant_role(
1702                session.user_id(),
1703                RoomId::from_proto(request.room_id),
1704                user_id,
1705                role,
1706            )
1707            .await?;
1708
1709        let livekit_room = room.livekit_room.clone();
1710        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1711        room_updated(&room, &session.peer);
1712        (livekit_room, can_publish)
1713    };
1714
1715    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1716        live_kit
1717            .update_participant(
1718                livekit_room.clone(),
1719                request.user_id.to_string(),
1720                livekit_api::proto::ParticipantPermission {
1721                    can_subscribe: true,
1722                    can_publish,
1723                    can_publish_data: can_publish,
1724                    hidden: false,
1725                    recorder: false,
1726                },
1727            )
1728            .await
1729            .trace_err();
1730    }
1731
1732    response.send(proto::Ack {})?;
1733    Ok(())
1734}
1735
1736/// Call someone else into the current room
1737async fn call(
1738    request: proto::Call,
1739    response: Response<proto::Call>,
1740    session: Session,
1741) -> Result<()> {
1742    let room_id = RoomId::from_proto(request.room_id);
1743    let calling_user_id = session.user_id();
1744    let calling_connection_id = session.connection_id;
1745    let called_user_id = UserId::from_proto(request.called_user_id);
1746    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1747    if !session
1748        .db()
1749        .await
1750        .has_contact(calling_user_id, called_user_id)
1751        .await?
1752    {
1753        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1754    }
1755
1756    let incoming_call = {
1757        let (room, incoming_call) = &mut *session
1758            .db()
1759            .await
1760            .call(
1761                room_id,
1762                calling_user_id,
1763                calling_connection_id,
1764                called_user_id,
1765                initial_project_id,
1766            )
1767            .await?;
1768        room_updated(room, &session.peer);
1769        mem::take(incoming_call)
1770    };
1771    update_user_contacts(called_user_id, &session).await?;
1772
1773    let mut calls = session
1774        .connection_pool()
1775        .await
1776        .user_connection_ids(called_user_id)
1777        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1778        .collect::<FuturesUnordered<_>>();
1779
1780    while let Some(call_response) = calls.next().await {
1781        match call_response.as_ref() {
1782            Ok(_) => {
1783                response.send(proto::Ack {})?;
1784                return Ok(());
1785            }
1786            Err(_) => {
1787                call_response.trace_err();
1788            }
1789        }
1790    }
1791
1792    {
1793        let room = session
1794            .db()
1795            .await
1796            .call_failed(room_id, called_user_id)
1797            .await?;
1798        room_updated(&room, &session.peer);
1799    }
1800    update_user_contacts(called_user_id, &session).await?;
1801
1802    Err(anyhow!("failed to ring user"))?
1803}
1804
1805/// Cancel an outgoing call.
1806async fn cancel_call(
1807    request: proto::CancelCall,
1808    response: Response<proto::CancelCall>,
1809    session: Session,
1810) -> Result<()> {
1811    let called_user_id = UserId::from_proto(request.called_user_id);
1812    let room_id = RoomId::from_proto(request.room_id);
1813    {
1814        let room = session
1815            .db()
1816            .await
1817            .cancel_call(room_id, session.connection_id, called_user_id)
1818            .await?;
1819        room_updated(&room, &session.peer);
1820    }
1821
1822    for connection_id in session
1823        .connection_pool()
1824        .await
1825        .user_connection_ids(called_user_id)
1826    {
1827        session
1828            .peer
1829            .send(
1830                connection_id,
1831                proto::CallCanceled {
1832                    room_id: room_id.to_proto(),
1833                },
1834            )
1835            .trace_err();
1836    }
1837    response.send(proto::Ack {})?;
1838
1839    update_user_contacts(called_user_id, &session).await?;
1840    Ok(())
1841}
1842
1843/// Decline an incoming call.
1844async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1845    let room_id = RoomId::from_proto(message.room_id);
1846    {
1847        let room = session
1848            .db()
1849            .await
1850            .decline_call(Some(room_id), session.user_id())
1851            .await?
1852            .context("declining call")?;
1853        room_updated(&room, &session.peer);
1854    }
1855
1856    for connection_id in session
1857        .connection_pool()
1858        .await
1859        .user_connection_ids(session.user_id())
1860    {
1861        session
1862            .peer
1863            .send(
1864                connection_id,
1865                proto::CallCanceled {
1866                    room_id: room_id.to_proto(),
1867                },
1868            )
1869            .trace_err();
1870    }
1871    update_user_contacts(session.user_id(), &session).await?;
1872    Ok(())
1873}
1874
1875/// Updates other participants in the room with your current location.
1876async fn update_participant_location(
1877    request: proto::UpdateParticipantLocation,
1878    response: Response<proto::UpdateParticipantLocation>,
1879    session: Session,
1880) -> Result<()> {
1881    let room_id = RoomId::from_proto(request.room_id);
1882    let location = request.location.context("invalid location")?;
1883
1884    let db = session.db().await;
1885    let room = db
1886        .update_room_participant_location(room_id, session.connection_id, location)
1887        .await?;
1888
1889    room_updated(&room, &session.peer);
1890    response.send(proto::Ack {})?;
1891    Ok(())
1892}
1893
1894/// Share a project into the room.
1895async fn share_project(
1896    request: proto::ShareProject,
1897    response: Response<proto::ShareProject>,
1898    session: Session,
1899) -> Result<()> {
1900    let (project_id, room) = &*session
1901        .db()
1902        .await
1903        .share_project(
1904            RoomId::from_proto(request.room_id),
1905            session.connection_id,
1906            &request.worktrees,
1907            request.is_ssh_project,
1908        )
1909        .await?;
1910    response.send(proto::ShareProjectResponse {
1911        project_id: project_id.to_proto(),
1912    })?;
1913    room_updated(room, &session.peer);
1914
1915    Ok(())
1916}
1917
1918/// Unshare a project from the room.
1919async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1920    let project_id = ProjectId::from_proto(message.project_id);
1921    unshare_project_internal(project_id, session.connection_id, &session).await
1922}
1923
1924async fn unshare_project_internal(
1925    project_id: ProjectId,
1926    connection_id: ConnectionId,
1927    session: &Session,
1928) -> Result<()> {
1929    let delete = {
1930        let room_guard = session
1931            .db()
1932            .await
1933            .unshare_project(project_id, connection_id)
1934            .await?;
1935
1936        let (delete, room, guest_connection_ids) = &*room_guard;
1937
1938        let message = proto::UnshareProject {
1939            project_id: project_id.to_proto(),
1940        };
1941
1942        broadcast(
1943            Some(connection_id),
1944            guest_connection_ids.iter().copied(),
1945            |conn_id| session.peer.send(conn_id, message.clone()),
1946        );
1947        if let Some(room) = room {
1948            room_updated(room, &session.peer);
1949        }
1950
1951        *delete
1952    };
1953
1954    if delete {
1955        let db = session.db().await;
1956        db.delete_project(project_id).await?;
1957    }
1958
1959    Ok(())
1960}
1961
1962/// Join someone elses shared project.
1963async fn join_project(
1964    request: proto::JoinProject,
1965    response: Response<proto::JoinProject>,
1966    session: Session,
1967) -> Result<()> {
1968    let project_id = ProjectId::from_proto(request.project_id);
1969
1970    tracing::info!(%project_id, "join project");
1971
1972    let db = session.db().await;
1973    let (project, replica_id) = &mut *db
1974        .join_project(
1975            project_id,
1976            session.connection_id,
1977            session.user_id(),
1978            request.committer_name.clone(),
1979            request.committer_email.clone(),
1980        )
1981        .await?;
1982    drop(db);
1983    tracing::info!(%project_id, "join remote project");
1984    let collaborators = project
1985        .collaborators
1986        .iter()
1987        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1988        .map(|collaborator| collaborator.to_proto())
1989        .collect::<Vec<_>>();
1990    let project_id = project.id;
1991    let guest_user_id = session.user_id();
1992
1993    let worktrees = project
1994        .worktrees
1995        .iter()
1996        .map(|(id, worktree)| proto::WorktreeMetadata {
1997            id: *id,
1998            root_name: worktree.root_name.clone(),
1999            visible: worktree.visible,
2000            abs_path: worktree.abs_path.clone(),
2001        })
2002        .collect::<Vec<_>>();
2003
2004    let add_project_collaborator = proto::AddProjectCollaborator {
2005        project_id: project_id.to_proto(),
2006        collaborator: Some(proto::Collaborator {
2007            peer_id: Some(session.connection_id.into()),
2008            replica_id: replica_id.0 as u32,
2009            user_id: guest_user_id.to_proto(),
2010            is_host: false,
2011            committer_name: request.committer_name.clone(),
2012            committer_email: request.committer_email.clone(),
2013        }),
2014    };
2015
2016    for collaborator in &collaborators {
2017        session
2018            .peer
2019            .send(
2020                collaborator.peer_id.unwrap().into(),
2021                add_project_collaborator.clone(),
2022            )
2023            .trace_err();
2024    }
2025
2026    // First, we send the metadata associated with each worktree.
2027    let (language_servers, language_server_capabilities) = project
2028        .language_servers
2029        .clone()
2030        .into_iter()
2031        .map(|server| (server.server, server.capabilities))
2032        .unzip();
2033    response.send(proto::JoinProjectResponse {
2034        project_id: project.id.0 as u64,
2035        worktrees: worktrees.clone(),
2036        replica_id: replica_id.0 as u32,
2037        collaborators: collaborators.clone(),
2038        language_servers,
2039        language_server_capabilities,
2040        role: project.role.into(),
2041    })?;
2042
2043    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2044        // Stream this worktree's entries.
2045        let message = proto::UpdateWorktree {
2046            project_id: project_id.to_proto(),
2047            worktree_id,
2048            abs_path: worktree.abs_path.clone(),
2049            root_name: worktree.root_name,
2050            updated_entries: worktree.entries,
2051            removed_entries: Default::default(),
2052            scan_id: worktree.scan_id,
2053            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2054            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2055            removed_repositories: Default::default(),
2056        };
2057        for update in proto::split_worktree_update(message) {
2058            session.peer.send(session.connection_id, update.clone())?;
2059        }
2060
2061        // Stream this worktree's diagnostics.
2062        for summary in worktree.diagnostic_summaries {
2063            session.peer.send(
2064                session.connection_id,
2065                proto::UpdateDiagnosticSummary {
2066                    project_id: project_id.to_proto(),
2067                    worktree_id: worktree.id,
2068                    summary: Some(summary),
2069                },
2070            )?;
2071        }
2072
2073        for settings_file in worktree.settings_files {
2074            session.peer.send(
2075                session.connection_id,
2076                proto::UpdateWorktreeSettings {
2077                    project_id: project_id.to_proto(),
2078                    worktree_id: worktree.id,
2079                    path: settings_file.path,
2080                    content: Some(settings_file.content),
2081                    kind: Some(settings_file.kind.to_proto() as i32),
2082                },
2083            )?;
2084        }
2085    }
2086
2087    for repository in mem::take(&mut project.repositories) {
2088        for update in split_repository_update(repository) {
2089            session.peer.send(session.connection_id, update)?;
2090        }
2091    }
2092
2093    for language_server in &project.language_servers {
2094        session.peer.send(
2095            session.connection_id,
2096            proto::UpdateLanguageServer {
2097                project_id: project_id.to_proto(),
2098                server_name: Some(language_server.server.name.clone()),
2099                language_server_id: language_server.server.id,
2100                variant: Some(
2101                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2102                        proto::LspDiskBasedDiagnosticsUpdated {},
2103                    ),
2104                ),
2105            },
2106        )?;
2107    }
2108
2109    Ok(())
2110}
2111
2112/// Leave someone elses shared project.
2113async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
2114    let sender_id = session.connection_id;
2115    let project_id = ProjectId::from_proto(request.project_id);
2116    let db = session.db().await;
2117
2118    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2119    tracing::info!(
2120        %project_id,
2121        "leave project"
2122    );
2123
2124    project_left(project, &session);
2125    if let Some(room) = room {
2126        room_updated(room, &session.peer);
2127    }
2128
2129    Ok(())
2130}
2131
2132/// Updates other participants with changes to the project
2133async fn update_project(
2134    request: proto::UpdateProject,
2135    response: Response<proto::UpdateProject>,
2136    session: Session,
2137) -> Result<()> {
2138    let project_id = ProjectId::from_proto(request.project_id);
2139    let (room, guest_connection_ids) = &*session
2140        .db()
2141        .await
2142        .update_project(project_id, session.connection_id, &request.worktrees)
2143        .await?;
2144    broadcast(
2145        Some(session.connection_id),
2146        guest_connection_ids.iter().copied(),
2147        |connection_id| {
2148            session
2149                .peer
2150                .forward_send(session.connection_id, connection_id, request.clone())
2151        },
2152    );
2153    if let Some(room) = room {
2154        room_updated(room, &session.peer);
2155    }
2156    response.send(proto::Ack {})?;
2157
2158    Ok(())
2159}
2160
2161/// Updates other participants with changes to the worktree
2162async fn update_worktree(
2163    request: proto::UpdateWorktree,
2164    response: Response<proto::UpdateWorktree>,
2165    session: Session,
2166) -> Result<()> {
2167    let guest_connection_ids = session
2168        .db()
2169        .await
2170        .update_worktree(&request, session.connection_id)
2171        .await?;
2172
2173    broadcast(
2174        Some(session.connection_id),
2175        guest_connection_ids.iter().copied(),
2176        |connection_id| {
2177            session
2178                .peer
2179                .forward_send(session.connection_id, connection_id, request.clone())
2180        },
2181    );
2182    response.send(proto::Ack {})?;
2183    Ok(())
2184}
2185
2186async fn update_repository(
2187    request: proto::UpdateRepository,
2188    response: Response<proto::UpdateRepository>,
2189    session: Session,
2190) -> Result<()> {
2191    let guest_connection_ids = session
2192        .db()
2193        .await
2194        .update_repository(&request, session.connection_id)
2195        .await?;
2196
2197    broadcast(
2198        Some(session.connection_id),
2199        guest_connection_ids.iter().copied(),
2200        |connection_id| {
2201            session
2202                .peer
2203                .forward_send(session.connection_id, connection_id, request.clone())
2204        },
2205    );
2206    response.send(proto::Ack {})?;
2207    Ok(())
2208}
2209
2210async fn remove_repository(
2211    request: proto::RemoveRepository,
2212    response: Response<proto::RemoveRepository>,
2213    session: Session,
2214) -> Result<()> {
2215    let guest_connection_ids = session
2216        .db()
2217        .await
2218        .remove_repository(&request, session.connection_id)
2219        .await?;
2220
2221    broadcast(
2222        Some(session.connection_id),
2223        guest_connection_ids.iter().copied(),
2224        |connection_id| {
2225            session
2226                .peer
2227                .forward_send(session.connection_id, connection_id, request.clone())
2228        },
2229    );
2230    response.send(proto::Ack {})?;
2231    Ok(())
2232}
2233
2234/// Updates other participants with changes to the diagnostics
2235async fn update_diagnostic_summary(
2236    message: proto::UpdateDiagnosticSummary,
2237    session: Session,
2238) -> Result<()> {
2239    let guest_connection_ids = session
2240        .db()
2241        .await
2242        .update_diagnostic_summary(&message, session.connection_id)
2243        .await?;
2244
2245    broadcast(
2246        Some(session.connection_id),
2247        guest_connection_ids.iter().copied(),
2248        |connection_id| {
2249            session
2250                .peer
2251                .forward_send(session.connection_id, connection_id, message.clone())
2252        },
2253    );
2254
2255    Ok(())
2256}
2257
2258/// Updates other participants with changes to the worktree settings
2259async fn update_worktree_settings(
2260    message: proto::UpdateWorktreeSettings,
2261    session: Session,
2262) -> Result<()> {
2263    let guest_connection_ids = session
2264        .db()
2265        .await
2266        .update_worktree_settings(&message, session.connection_id)
2267        .await?;
2268
2269    broadcast(
2270        Some(session.connection_id),
2271        guest_connection_ids.iter().copied(),
2272        |connection_id| {
2273            session
2274                .peer
2275                .forward_send(session.connection_id, connection_id, message.clone())
2276        },
2277    );
2278
2279    Ok(())
2280}
2281
2282/// Notify other participants that a language server has started.
2283async fn start_language_server(
2284    request: proto::StartLanguageServer,
2285    session: Session,
2286) -> Result<()> {
2287    let guest_connection_ids = session
2288        .db()
2289        .await
2290        .start_language_server(&request, session.connection_id)
2291        .await?;
2292
2293    broadcast(
2294        Some(session.connection_id),
2295        guest_connection_ids.iter().copied(),
2296        |connection_id| {
2297            session
2298                .peer
2299                .forward_send(session.connection_id, connection_id, request.clone())
2300        },
2301    );
2302    Ok(())
2303}
2304
2305/// Notify other participants that a language server has changed.
2306async fn update_language_server(
2307    request: proto::UpdateLanguageServer,
2308    session: Session,
2309) -> Result<()> {
2310    let project_id = ProjectId::from_proto(request.project_id);
2311    let db = session.db().await;
2312
2313    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2314    {
2315        if let Some(capabilities) = update.capabilities.clone() {
2316            db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2317                .await?;
2318        }
2319    }
2320
2321    let project_connection_ids = db
2322        .project_connection_ids(project_id, session.connection_id, true)
2323        .await?;
2324    broadcast(
2325        Some(session.connection_id),
2326        project_connection_ids.iter().copied(),
2327        |connection_id| {
2328            session
2329                .peer
2330                .forward_send(session.connection_id, connection_id, request.clone())
2331        },
2332    );
2333    Ok(())
2334}
2335
2336/// forward a project request to the host. These requests should be read only
2337/// as guests are allowed to send them.
2338async fn forward_read_only_project_request<T>(
2339    request: T,
2340    response: Response<T>,
2341    session: Session,
2342) -> Result<()>
2343where
2344    T: EntityMessage + RequestMessage,
2345{
2346    let project_id = ProjectId::from_proto(request.remote_entity_id());
2347    let host_connection_id = session
2348        .db()
2349        .await
2350        .host_for_read_only_project_request(project_id, session.connection_id)
2351        .await?;
2352    let payload = session
2353        .peer
2354        .forward_request(session.connection_id, host_connection_id, request)
2355        .await?;
2356    response.send(payload)?;
2357    Ok(())
2358}
2359
2360/// forward a project request to the host. These requests are disallowed
2361/// for guests.
2362async fn forward_mutating_project_request<T>(
2363    request: T,
2364    response: Response<T>,
2365    session: Session,
2366) -> Result<()>
2367where
2368    T: EntityMessage + RequestMessage,
2369{
2370    let project_id = ProjectId::from_proto(request.remote_entity_id());
2371
2372    let host_connection_id = session
2373        .db()
2374        .await
2375        .host_for_mutating_project_request(project_id, session.connection_id)
2376        .await?;
2377    let payload = session
2378        .peer
2379        .forward_request(session.connection_id, host_connection_id, request)
2380        .await?;
2381    response.send(payload)?;
2382    Ok(())
2383}
2384
2385async fn multi_lsp_query(
2386    request: MultiLspQuery,
2387    response: Response<MultiLspQuery>,
2388    session: Session,
2389) -> Result<()> {
2390    tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2391    tracing::info!("multi_lsp_query message received");
2392    forward_mutating_project_request(request, response, session).await
2393}
2394
2395/// Notify other participants that a new buffer has been created
2396async fn create_buffer_for_peer(
2397    request: proto::CreateBufferForPeer,
2398    session: Session,
2399) -> Result<()> {
2400    session
2401        .db()
2402        .await
2403        .check_user_is_project_host(
2404            ProjectId::from_proto(request.project_id),
2405            session.connection_id,
2406        )
2407        .await?;
2408    let peer_id = request.peer_id.context("invalid peer id")?;
2409    session
2410        .peer
2411        .forward_send(session.connection_id, peer_id.into(), request)?;
2412    Ok(())
2413}
2414
2415/// Notify other participants that a buffer has been updated. This is
2416/// allowed for guests as long as the update is limited to selections.
2417async fn update_buffer(
2418    request: proto::UpdateBuffer,
2419    response: Response<proto::UpdateBuffer>,
2420    session: Session,
2421) -> Result<()> {
2422    let project_id = ProjectId::from_proto(request.project_id);
2423    let mut capability = Capability::ReadOnly;
2424
2425    for op in request.operations.iter() {
2426        match op.variant {
2427            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2428            Some(_) => capability = Capability::ReadWrite,
2429        }
2430    }
2431
2432    let host = {
2433        let guard = session
2434            .db()
2435            .await
2436            .connections_for_buffer_update(project_id, session.connection_id, capability)
2437            .await?;
2438
2439        let (host, guests) = &*guard;
2440
2441        broadcast(
2442            Some(session.connection_id),
2443            guests.clone(),
2444            |connection_id| {
2445                session
2446                    .peer
2447                    .forward_send(session.connection_id, connection_id, request.clone())
2448            },
2449        );
2450
2451        *host
2452    };
2453
2454    if host != session.connection_id {
2455        session
2456            .peer
2457            .forward_request(session.connection_id, host, request.clone())
2458            .await?;
2459    }
2460
2461    response.send(proto::Ack {})?;
2462    Ok(())
2463}
2464
2465async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2466    let project_id = ProjectId::from_proto(message.project_id);
2467
2468    let operation = message.operation.as_ref().context("invalid operation")?;
2469    let capability = match operation.variant.as_ref() {
2470        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2471            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2472                match buffer_op.variant {
2473                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2474                        Capability::ReadOnly
2475                    }
2476                    _ => Capability::ReadWrite,
2477                }
2478            } else {
2479                Capability::ReadWrite
2480            }
2481        }
2482        Some(_) => Capability::ReadWrite,
2483        None => Capability::ReadOnly,
2484    };
2485
2486    let guard = session
2487        .db()
2488        .await
2489        .connections_for_buffer_update(project_id, session.connection_id, capability)
2490        .await?;
2491
2492    let (host, guests) = &*guard;
2493
2494    broadcast(
2495        Some(session.connection_id),
2496        guests.iter().chain([host]).copied(),
2497        |connection_id| {
2498            session
2499                .peer
2500                .forward_send(session.connection_id, connection_id, message.clone())
2501        },
2502    );
2503
2504    Ok(())
2505}
2506
2507/// Notify other participants that a project has been updated.
2508async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2509    request: T,
2510    session: Session,
2511) -> Result<()> {
2512    let project_id = ProjectId::from_proto(request.remote_entity_id());
2513    let project_connection_ids = session
2514        .db()
2515        .await
2516        .project_connection_ids(project_id, session.connection_id, false)
2517        .await?;
2518
2519    broadcast(
2520        Some(session.connection_id),
2521        project_connection_ids.iter().copied(),
2522        |connection_id| {
2523            session
2524                .peer
2525                .forward_send(session.connection_id, connection_id, request.clone())
2526        },
2527    );
2528    Ok(())
2529}
2530
2531/// Start following another user in a call.
2532async fn follow(
2533    request: proto::Follow,
2534    response: Response<proto::Follow>,
2535    session: Session,
2536) -> Result<()> {
2537    let room_id = RoomId::from_proto(request.room_id);
2538    let project_id = request.project_id.map(ProjectId::from_proto);
2539    let leader_id = request.leader_id.context("invalid leader id")?.into();
2540    let follower_id = session.connection_id;
2541
2542    session
2543        .db()
2544        .await
2545        .check_room_participants(room_id, leader_id, session.connection_id)
2546        .await?;
2547
2548    let response_payload = session
2549        .peer
2550        .forward_request(session.connection_id, leader_id, request)
2551        .await?;
2552    response.send(response_payload)?;
2553
2554    if let Some(project_id) = project_id {
2555        let room = session
2556            .db()
2557            .await
2558            .follow(room_id, project_id, leader_id, follower_id)
2559            .await?;
2560        room_updated(&room, &session.peer);
2561    }
2562
2563    Ok(())
2564}
2565
2566/// Stop following another user in a call.
2567async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2568    let room_id = RoomId::from_proto(request.room_id);
2569    let project_id = request.project_id.map(ProjectId::from_proto);
2570    let leader_id = request.leader_id.context("invalid leader id")?.into();
2571    let follower_id = session.connection_id;
2572
2573    session
2574        .db()
2575        .await
2576        .check_room_participants(room_id, leader_id, session.connection_id)
2577        .await?;
2578
2579    session
2580        .peer
2581        .forward_send(session.connection_id, leader_id, request)?;
2582
2583    if let Some(project_id) = project_id {
2584        let room = session
2585            .db()
2586            .await
2587            .unfollow(room_id, project_id, leader_id, follower_id)
2588            .await?;
2589        room_updated(&room, &session.peer);
2590    }
2591
2592    Ok(())
2593}
2594
2595/// Notify everyone following you of your current location.
2596async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2597    let room_id = RoomId::from_proto(request.room_id);
2598    let database = session.db.lock().await;
2599
2600    let connection_ids = if let Some(project_id) = request.project_id {
2601        let project_id = ProjectId::from_proto(project_id);
2602        database
2603            .project_connection_ids(project_id, session.connection_id, true)
2604            .await?
2605    } else {
2606        database
2607            .room_connection_ids(room_id, session.connection_id)
2608            .await?
2609    };
2610
2611    // For now, don't send view update messages back to that view's current leader.
2612    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2613        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2614        _ => None,
2615    });
2616
2617    for connection_id in connection_ids.iter().cloned() {
2618        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2619            session
2620                .peer
2621                .forward_send(session.connection_id, connection_id, request.clone())?;
2622        }
2623    }
2624    Ok(())
2625}
2626
2627/// Get public data about users.
2628async fn get_users(
2629    request: proto::GetUsers,
2630    response: Response<proto::GetUsers>,
2631    session: Session,
2632) -> Result<()> {
2633    let user_ids = request
2634        .user_ids
2635        .into_iter()
2636        .map(UserId::from_proto)
2637        .collect();
2638    let users = session
2639        .db()
2640        .await
2641        .get_users_by_ids(user_ids)
2642        .await?
2643        .into_iter()
2644        .map(|user| proto::User {
2645            id: user.id.to_proto(),
2646            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2647            github_login: user.github_login,
2648            name: user.name,
2649        })
2650        .collect();
2651    response.send(proto::UsersResponse { users })?;
2652    Ok(())
2653}
2654
2655/// Search for users (to invite) buy Github login
2656async fn fuzzy_search_users(
2657    request: proto::FuzzySearchUsers,
2658    response: Response<proto::FuzzySearchUsers>,
2659    session: Session,
2660) -> Result<()> {
2661    let query = request.query;
2662    let users = match query.len() {
2663        0 => vec![],
2664        1 | 2 => session
2665            .db()
2666            .await
2667            .get_user_by_github_login(&query)
2668            .await?
2669            .into_iter()
2670            .collect(),
2671        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2672    };
2673    let users = users
2674        .into_iter()
2675        .filter(|user| user.id != session.user_id())
2676        .map(|user| proto::User {
2677            id: user.id.to_proto(),
2678            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2679            github_login: user.github_login,
2680            name: user.name,
2681        })
2682        .collect();
2683    response.send(proto::UsersResponse { users })?;
2684    Ok(())
2685}
2686
2687/// Send a contact request to another user.
2688async fn request_contact(
2689    request: proto::RequestContact,
2690    response: Response<proto::RequestContact>,
2691    session: Session,
2692) -> Result<()> {
2693    let requester_id = session.user_id();
2694    let responder_id = UserId::from_proto(request.responder_id);
2695    if requester_id == responder_id {
2696        return Err(anyhow!("cannot add yourself as a contact"))?;
2697    }
2698
2699    let notifications = session
2700        .db()
2701        .await
2702        .send_contact_request(requester_id, responder_id)
2703        .await?;
2704
2705    // Update outgoing contact requests of requester
2706    let mut update = proto::UpdateContacts::default();
2707    update.outgoing_requests.push(responder_id.to_proto());
2708    for connection_id in session
2709        .connection_pool()
2710        .await
2711        .user_connection_ids(requester_id)
2712    {
2713        session.peer.send(connection_id, update.clone())?;
2714    }
2715
2716    // Update incoming contact requests of responder
2717    let mut update = proto::UpdateContacts::default();
2718    update
2719        .incoming_requests
2720        .push(proto::IncomingContactRequest {
2721            requester_id: requester_id.to_proto(),
2722        });
2723    let connection_pool = session.connection_pool().await;
2724    for connection_id in connection_pool.user_connection_ids(responder_id) {
2725        session.peer.send(connection_id, update.clone())?;
2726    }
2727
2728    send_notifications(&connection_pool, &session.peer, notifications);
2729
2730    response.send(proto::Ack {})?;
2731    Ok(())
2732}
2733
2734/// Accept or decline a contact request
2735async fn respond_to_contact_request(
2736    request: proto::RespondToContactRequest,
2737    response: Response<proto::RespondToContactRequest>,
2738    session: Session,
2739) -> Result<()> {
2740    let responder_id = session.user_id();
2741    let requester_id = UserId::from_proto(request.requester_id);
2742    let db = session.db().await;
2743    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2744        db.dismiss_contact_notification(responder_id, requester_id)
2745            .await?;
2746    } else {
2747        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2748
2749        let notifications = db
2750            .respond_to_contact_request(responder_id, requester_id, accept)
2751            .await?;
2752        let requester_busy = db.is_user_busy(requester_id).await?;
2753        let responder_busy = db.is_user_busy(responder_id).await?;
2754
2755        let pool = session.connection_pool().await;
2756        // Update responder with new contact
2757        let mut update = proto::UpdateContacts::default();
2758        if accept {
2759            update
2760                .contacts
2761                .push(contact_for_user(requester_id, requester_busy, &pool));
2762        }
2763        update
2764            .remove_incoming_requests
2765            .push(requester_id.to_proto());
2766        for connection_id in pool.user_connection_ids(responder_id) {
2767            session.peer.send(connection_id, update.clone())?;
2768        }
2769
2770        // Update requester with new contact
2771        let mut update = proto::UpdateContacts::default();
2772        if accept {
2773            update
2774                .contacts
2775                .push(contact_for_user(responder_id, responder_busy, &pool));
2776        }
2777        update
2778            .remove_outgoing_requests
2779            .push(responder_id.to_proto());
2780
2781        for connection_id in pool.user_connection_ids(requester_id) {
2782            session.peer.send(connection_id, update.clone())?;
2783        }
2784
2785        send_notifications(&pool, &session.peer, notifications);
2786    }
2787
2788    response.send(proto::Ack {})?;
2789    Ok(())
2790}
2791
2792/// Remove a contact.
2793async fn remove_contact(
2794    request: proto::RemoveContact,
2795    response: Response<proto::RemoveContact>,
2796    session: Session,
2797) -> Result<()> {
2798    let requester_id = session.user_id();
2799    let responder_id = UserId::from_proto(request.user_id);
2800    let db = session.db().await;
2801    let (contact_accepted, deleted_notification_id) =
2802        db.remove_contact(requester_id, responder_id).await?;
2803
2804    let pool = session.connection_pool().await;
2805    // Update outgoing contact requests of requester
2806    let mut update = proto::UpdateContacts::default();
2807    if contact_accepted {
2808        update.remove_contacts.push(responder_id.to_proto());
2809    } else {
2810        update
2811            .remove_outgoing_requests
2812            .push(responder_id.to_proto());
2813    }
2814    for connection_id in pool.user_connection_ids(requester_id) {
2815        session.peer.send(connection_id, update.clone())?;
2816    }
2817
2818    // Update incoming contact requests of responder
2819    let mut update = proto::UpdateContacts::default();
2820    if contact_accepted {
2821        update.remove_contacts.push(requester_id.to_proto());
2822    } else {
2823        update
2824            .remove_incoming_requests
2825            .push(requester_id.to_proto());
2826    }
2827    for connection_id in pool.user_connection_ids(responder_id) {
2828        session.peer.send(connection_id, update.clone())?;
2829        if let Some(notification_id) = deleted_notification_id {
2830            session.peer.send(
2831                connection_id,
2832                proto::DeleteNotification {
2833                    notification_id: notification_id.to_proto(),
2834                },
2835            )?;
2836        }
2837    }
2838
2839    response.send(proto::Ack {})?;
2840    Ok(())
2841}
2842
2843fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2844    version.0.minor() < 139
2845}
2846
2847async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2848    if is_staff {
2849        return Ok(proto::Plan::ZedPro);
2850    }
2851
2852    let subscription = db.get_active_billing_subscription(user_id).await?;
2853    let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2854
2855    let plan = if let Some(subscription_kind) = subscription_kind {
2856        match subscription_kind {
2857            SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2858            SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2859            SubscriptionKind::ZedFree => proto::Plan::Free,
2860        }
2861    } else {
2862        proto::Plan::Free
2863    };
2864
2865    Ok(plan)
2866}
2867
2868async fn make_update_user_plan_message(
2869    user: &User,
2870    is_staff: bool,
2871    db: &Arc<Database>,
2872    llm_db: Option<Arc<LlmDatabase>>,
2873) -> Result<proto::UpdateUserPlan> {
2874    let feature_flags = db.get_user_flags(user.id).await?;
2875    let plan = current_plan(db, user.id, is_staff).await?;
2876    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
2877    let billing_preferences = db.get_billing_preferences(user.id).await?;
2878
2879    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
2880        let subscription = db.get_active_billing_subscription(user.id).await?;
2881
2882        let subscription_period =
2883            crate::db::billing_subscription::Model::current_period(subscription, is_staff);
2884
2885        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
2886            llm_db
2887                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
2888                .await?
2889        } else {
2890            None
2891        };
2892
2893        (subscription_period, usage)
2894    } else {
2895        (None, None)
2896    };
2897
2898    let bypass_account_age_check = feature_flags
2899        .iter()
2900        .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
2901    let account_too_young = !matches!(plan, proto::Plan::ZedPro)
2902        && !bypass_account_age_check
2903        && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;
2904
2905    Ok(proto::UpdateUserPlan {
2906        plan: plan.into(),
2907        trial_started_at: billing_customer
2908            .as_ref()
2909            .and_then(|billing_customer| billing_customer.trial_started_at)
2910            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2911        is_usage_based_billing_enabled: if is_staff {
2912            Some(true)
2913        } else {
2914            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
2915        },
2916        subscription_period: subscription_period.map(|(started_at, ended_at)| {
2917            proto::SubscriptionPeriod {
2918                started_at: started_at.timestamp() as u64,
2919                ended_at: ended_at.timestamp() as u64,
2920            }
2921        }),
2922        account_too_young: Some(account_too_young),
2923        has_overdue_invoices: billing_customer
2924            .map(|billing_customer| billing_customer.has_overdue_invoices),
2925        usage: Some(
2926            usage
2927                .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
2928                .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
2929        ),
2930    })
2931}
2932
2933fn model_requests_limit(
2934    plan: cloud_llm_client::Plan,
2935    feature_flags: &Vec<String>,
2936) -> cloud_llm_client::UsageLimit {
2937    match plan.model_requests_limit() {
2938        cloud_llm_client::UsageLimit::Limited(limit) => {
2939            let limit = if plan == cloud_llm_client::Plan::ZedProTrial
2940                && feature_flags
2941                    .iter()
2942                    .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2943            {
2944                1_000
2945            } else {
2946                limit
2947            };
2948
2949            cloud_llm_client::UsageLimit::Limited(limit)
2950        }
2951        cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited,
2952    }
2953}
2954
2955fn subscription_usage_to_proto(
2956    plan: proto::Plan,
2957    usage: crate::llm::db::subscription_usage::Model,
2958    feature_flags: &Vec<String>,
2959) -> proto::SubscriptionUsage {
2960    let plan = match plan {
2961        proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2962        proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
2963        proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
2964    };
2965
2966    proto::SubscriptionUsage {
2967        model_requests_usage_amount: usage.model_requests as u32,
2968        model_requests_usage_limit: Some(proto::UsageLimit {
2969            variant: Some(match model_requests_limit(plan, feature_flags) {
2970                cloud_llm_client::UsageLimit::Limited(limit) => {
2971                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2972                        limit: limit as u32,
2973                    })
2974                }
2975                cloud_llm_client::UsageLimit::Unlimited => {
2976                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2977                }
2978            }),
2979        }),
2980        edit_predictions_usage_amount: usage.edit_predictions as u32,
2981        edit_predictions_usage_limit: Some(proto::UsageLimit {
2982            variant: Some(match plan.edit_predictions_limit() {
2983                cloud_llm_client::UsageLimit::Limited(limit) => {
2984                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2985                        limit: limit as u32,
2986                    })
2987                }
2988                cloud_llm_client::UsageLimit::Unlimited => {
2989                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2990                }
2991            }),
2992        }),
2993    }
2994}
2995
2996fn make_default_subscription_usage(
2997    plan: proto::Plan,
2998    feature_flags: &Vec<String>,
2999) -> proto::SubscriptionUsage {
3000    let plan = match plan {
3001        proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
3002        proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
3003        proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
3004    };
3005
3006    proto::SubscriptionUsage {
3007        model_requests_usage_amount: 0,
3008        model_requests_usage_limit: Some(proto::UsageLimit {
3009            variant: Some(match model_requests_limit(plan, feature_flags) {
3010                cloud_llm_client::UsageLimit::Limited(limit) => {
3011                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3012                        limit: limit as u32,
3013                    })
3014                }
3015                cloud_llm_client::UsageLimit::Unlimited => {
3016                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3017                }
3018            }),
3019        }),
3020        edit_predictions_usage_amount: 0,
3021        edit_predictions_usage_limit: Some(proto::UsageLimit {
3022            variant: Some(match plan.edit_predictions_limit() {
3023                cloud_llm_client::UsageLimit::Limited(limit) => {
3024                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3025                        limit: limit as u32,
3026                    })
3027                }
3028                cloud_llm_client::UsageLimit::Unlimited => {
3029                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3030                }
3031            }),
3032        }),
3033    }
3034}
3035
3036async fn update_user_plan(session: &Session) -> Result<()> {
3037    let db = session.db().await;
3038
3039    let update_user_plan = make_update_user_plan_message(
3040        session.principal.user(),
3041        session.is_staff(),
3042        &db.0,
3043        session.app_state.llm_db.clone(),
3044    )
3045    .await?;
3046
3047    session
3048        .peer
3049        .send(session.connection_id, update_user_plan)
3050        .trace_err();
3051
3052    Ok(())
3053}
3054
3055async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3056    subscribe_user_to_channels(session.user_id(), &session).await?;
3057    Ok(())
3058}
3059
3060async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3061    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3062    let mut pool = session.connection_pool().await;
3063    for membership in &channels_for_user.channel_memberships {
3064        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3065    }
3066    session.peer.send(
3067        session.connection_id,
3068        build_update_user_channels(&channels_for_user),
3069    )?;
3070    session.peer.send(
3071        session.connection_id,
3072        build_channels_update(channels_for_user),
3073    )?;
3074    Ok(())
3075}
3076
3077/// Creates a new channel.
3078async fn create_channel(
3079    request: proto::CreateChannel,
3080    response: Response<proto::CreateChannel>,
3081    session: Session,
3082) -> Result<()> {
3083    let db = session.db().await;
3084
3085    let parent_id = request.parent_id.map(ChannelId::from_proto);
3086    let (channel, membership) = db
3087        .create_channel(&request.name, parent_id, session.user_id())
3088        .await?;
3089
3090    let root_id = channel.root_id();
3091    let channel = Channel::from_model(channel);
3092
3093    response.send(proto::CreateChannelResponse {
3094        channel: Some(channel.to_proto()),
3095        parent_id: request.parent_id,
3096    })?;
3097
3098    let mut connection_pool = session.connection_pool().await;
3099    if let Some(membership) = membership {
3100        connection_pool.subscribe_to_channel(
3101            membership.user_id,
3102            membership.channel_id,
3103            membership.role,
3104        );
3105        let update = proto::UpdateUserChannels {
3106            channel_memberships: vec![proto::ChannelMembership {
3107                channel_id: membership.channel_id.to_proto(),
3108                role: membership.role.into(),
3109            }],
3110            ..Default::default()
3111        };
3112        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3113            session.peer.send(connection_id, update.clone())?;
3114        }
3115    }
3116
3117    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3118        if !role.can_see_channel(channel.visibility) {
3119            continue;
3120        }
3121
3122        let update = proto::UpdateChannels {
3123            channels: vec![channel.to_proto()],
3124            ..Default::default()
3125        };
3126        session.peer.send(connection_id, update.clone())?;
3127    }
3128
3129    Ok(())
3130}
3131
3132/// Delete a channel
3133async fn delete_channel(
3134    request: proto::DeleteChannel,
3135    response: Response<proto::DeleteChannel>,
3136    session: Session,
3137) -> Result<()> {
3138    let db = session.db().await;
3139
3140    let channel_id = request.channel_id;
3141    let (root_channel, removed_channels) = db
3142        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3143        .await?;
3144    response.send(proto::Ack {})?;
3145
3146    // Notify members of removed channels
3147    let mut update = proto::UpdateChannels::default();
3148    update
3149        .delete_channels
3150        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3151
3152    let connection_pool = session.connection_pool().await;
3153    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3154        session.peer.send(connection_id, update.clone())?;
3155    }
3156
3157    Ok(())
3158}
3159
3160/// Invite someone to join a channel.
3161async fn invite_channel_member(
3162    request: proto::InviteChannelMember,
3163    response: Response<proto::InviteChannelMember>,
3164    session: Session,
3165) -> Result<()> {
3166    let db = session.db().await;
3167    let channel_id = ChannelId::from_proto(request.channel_id);
3168    let invitee_id = UserId::from_proto(request.user_id);
3169    let InviteMemberResult {
3170        channel,
3171        notifications,
3172    } = db
3173        .invite_channel_member(
3174            channel_id,
3175            invitee_id,
3176            session.user_id(),
3177            request.role().into(),
3178        )
3179        .await?;
3180
3181    let update = proto::UpdateChannels {
3182        channel_invitations: vec![channel.to_proto()],
3183        ..Default::default()
3184    };
3185
3186    let connection_pool = session.connection_pool().await;
3187    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3188        session.peer.send(connection_id, update.clone())?;
3189    }
3190
3191    send_notifications(&connection_pool, &session.peer, notifications);
3192
3193    response.send(proto::Ack {})?;
3194    Ok(())
3195}
3196
3197/// remove someone from a channel
3198async fn remove_channel_member(
3199    request: proto::RemoveChannelMember,
3200    response: Response<proto::RemoveChannelMember>,
3201    session: Session,
3202) -> Result<()> {
3203    let db = session.db().await;
3204    let channel_id = ChannelId::from_proto(request.channel_id);
3205    let member_id = UserId::from_proto(request.user_id);
3206
3207    let RemoveChannelMemberResult {
3208        membership_update,
3209        notification_id,
3210    } = db
3211        .remove_channel_member(channel_id, member_id, session.user_id())
3212        .await?;
3213
3214    let mut connection_pool = session.connection_pool().await;
3215    notify_membership_updated(
3216        &mut connection_pool,
3217        membership_update,
3218        member_id,
3219        &session.peer,
3220    );
3221    for connection_id in connection_pool.user_connection_ids(member_id) {
3222        if let Some(notification_id) = notification_id {
3223            session
3224                .peer
3225                .send(
3226                    connection_id,
3227                    proto::DeleteNotification {
3228                        notification_id: notification_id.to_proto(),
3229                    },
3230                )
3231                .trace_err();
3232        }
3233    }
3234
3235    response.send(proto::Ack {})?;
3236    Ok(())
3237}
3238
3239/// Toggle the channel between public and private.
3240/// Care is taken to maintain the invariant that public channels only descend from public channels,
3241/// (though members-only channels can appear at any point in the hierarchy).
3242async fn set_channel_visibility(
3243    request: proto::SetChannelVisibility,
3244    response: Response<proto::SetChannelVisibility>,
3245    session: Session,
3246) -> Result<()> {
3247    let db = session.db().await;
3248    let channel_id = ChannelId::from_proto(request.channel_id);
3249    let visibility = request.visibility().into();
3250
3251    let channel_model = db
3252        .set_channel_visibility(channel_id, visibility, session.user_id())
3253        .await?;
3254    let root_id = channel_model.root_id();
3255    let channel = Channel::from_model(channel_model);
3256
3257    let mut connection_pool = session.connection_pool().await;
3258    for (user_id, role) in connection_pool
3259        .channel_user_ids(root_id)
3260        .collect::<Vec<_>>()
3261        .into_iter()
3262    {
3263        let update = if role.can_see_channel(channel.visibility) {
3264            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3265            proto::UpdateChannels {
3266                channels: vec![channel.to_proto()],
3267                ..Default::default()
3268            }
3269        } else {
3270            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3271            proto::UpdateChannels {
3272                delete_channels: vec![channel.id.to_proto()],
3273                ..Default::default()
3274            }
3275        };
3276
3277        for connection_id in connection_pool.user_connection_ids(user_id) {
3278            session.peer.send(connection_id, update.clone())?;
3279        }
3280    }
3281
3282    response.send(proto::Ack {})?;
3283    Ok(())
3284}
3285
3286/// Alter the role for a user in the channel.
3287async fn set_channel_member_role(
3288    request: proto::SetChannelMemberRole,
3289    response: Response<proto::SetChannelMemberRole>,
3290    session: Session,
3291) -> Result<()> {
3292    let db = session.db().await;
3293    let channel_id = ChannelId::from_proto(request.channel_id);
3294    let member_id = UserId::from_proto(request.user_id);
3295    let result = db
3296        .set_channel_member_role(
3297            channel_id,
3298            session.user_id(),
3299            member_id,
3300            request.role().into(),
3301        )
3302        .await?;
3303
3304    match result {
3305        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3306            let mut connection_pool = session.connection_pool().await;
3307            notify_membership_updated(
3308                &mut connection_pool,
3309                membership_update,
3310                member_id,
3311                &session.peer,
3312            )
3313        }
3314        db::SetMemberRoleResult::InviteUpdated(channel) => {
3315            let update = proto::UpdateChannels {
3316                channel_invitations: vec![channel.to_proto()],
3317                ..Default::default()
3318            };
3319
3320            for connection_id in session
3321                .connection_pool()
3322                .await
3323                .user_connection_ids(member_id)
3324            {
3325                session.peer.send(connection_id, update.clone())?;
3326            }
3327        }
3328    }
3329
3330    response.send(proto::Ack {})?;
3331    Ok(())
3332}
3333
3334/// Change the name of a channel
3335async fn rename_channel(
3336    request: proto::RenameChannel,
3337    response: Response<proto::RenameChannel>,
3338    session: Session,
3339) -> Result<()> {
3340    let db = session.db().await;
3341    let channel_id = ChannelId::from_proto(request.channel_id);
3342    let channel_model = db
3343        .rename_channel(channel_id, session.user_id(), &request.name)
3344        .await?;
3345    let root_id = channel_model.root_id();
3346    let channel = Channel::from_model(channel_model);
3347
3348    response.send(proto::RenameChannelResponse {
3349        channel: Some(channel.to_proto()),
3350    })?;
3351
3352    let connection_pool = session.connection_pool().await;
3353    let update = proto::UpdateChannels {
3354        channels: vec![channel.to_proto()],
3355        ..Default::default()
3356    };
3357    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3358        if role.can_see_channel(channel.visibility) {
3359            session.peer.send(connection_id, update.clone())?;
3360        }
3361    }
3362
3363    Ok(())
3364}
3365
3366/// Move a channel to a new parent.
3367async fn move_channel(
3368    request: proto::MoveChannel,
3369    response: Response<proto::MoveChannel>,
3370    session: Session,
3371) -> Result<()> {
3372    let channel_id = ChannelId::from_proto(request.channel_id);
3373    let to = ChannelId::from_proto(request.to);
3374
3375    let (root_id, channels) = session
3376        .db()
3377        .await
3378        .move_channel(channel_id, to, session.user_id())
3379        .await?;
3380
3381    let connection_pool = session.connection_pool().await;
3382    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3383        let channels = channels
3384            .iter()
3385            .filter_map(|channel| {
3386                if role.can_see_channel(channel.visibility) {
3387                    Some(channel.to_proto())
3388                } else {
3389                    None
3390                }
3391            })
3392            .collect::<Vec<_>>();
3393        if channels.is_empty() {
3394            continue;
3395        }
3396
3397        let update = proto::UpdateChannels {
3398            channels,
3399            ..Default::default()
3400        };
3401
3402        session.peer.send(connection_id, update.clone())?;
3403    }
3404
3405    response.send(Ack {})?;
3406    Ok(())
3407}
3408
3409async fn reorder_channel(
3410    request: proto::ReorderChannel,
3411    response: Response<proto::ReorderChannel>,
3412    session: Session,
3413) -> Result<()> {
3414    let channel_id = ChannelId::from_proto(request.channel_id);
3415    let direction = request.direction();
3416
3417    let updated_channels = session
3418        .db()
3419        .await
3420        .reorder_channel(channel_id, direction, session.user_id())
3421        .await?;
3422
3423    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3424        let connection_pool = session.connection_pool().await;
3425        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3426            let channels = updated_channels
3427                .iter()
3428                .filter_map(|channel| {
3429                    if role.can_see_channel(channel.visibility) {
3430                        Some(channel.to_proto())
3431                    } else {
3432                        None
3433                    }
3434                })
3435                .collect::<Vec<_>>();
3436
3437            if channels.is_empty() {
3438                continue;
3439            }
3440
3441            let update = proto::UpdateChannels {
3442                channels,
3443                ..Default::default()
3444            };
3445
3446            session.peer.send(connection_id, update.clone())?;
3447        }
3448    }
3449
3450    response.send(Ack {})?;
3451    Ok(())
3452}
3453
3454/// Get the list of channel members
3455async fn get_channel_members(
3456    request: proto::GetChannelMembers,
3457    response: Response<proto::GetChannelMembers>,
3458    session: Session,
3459) -> Result<()> {
3460    let db = session.db().await;
3461    let channel_id = ChannelId::from_proto(request.channel_id);
3462    let limit = if request.limit == 0 {
3463        u16::MAX as u64
3464    } else {
3465        request.limit
3466    };
3467    let (members, users) = db
3468        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3469        .await?;
3470    response.send(proto::GetChannelMembersResponse { members, users })?;
3471    Ok(())
3472}
3473
3474/// Accept or decline a channel invitation.
3475async fn respond_to_channel_invite(
3476    request: proto::RespondToChannelInvite,
3477    response: Response<proto::RespondToChannelInvite>,
3478    session: Session,
3479) -> Result<()> {
3480    let db = session.db().await;
3481    let channel_id = ChannelId::from_proto(request.channel_id);
3482    let RespondToChannelInvite {
3483        membership_update,
3484        notifications,
3485    } = db
3486        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3487        .await?;
3488
3489    let mut connection_pool = session.connection_pool().await;
3490    if let Some(membership_update) = membership_update {
3491        notify_membership_updated(
3492            &mut connection_pool,
3493            membership_update,
3494            session.user_id(),
3495            &session.peer,
3496        );
3497    } else {
3498        let update = proto::UpdateChannels {
3499            remove_channel_invitations: vec![channel_id.to_proto()],
3500            ..Default::default()
3501        };
3502
3503        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3504            session.peer.send(connection_id, update.clone())?;
3505        }
3506    };
3507
3508    send_notifications(&connection_pool, &session.peer, notifications);
3509
3510    response.send(proto::Ack {})?;
3511
3512    Ok(())
3513}
3514
3515/// Join the channels' room
3516async fn join_channel(
3517    request: proto::JoinChannel,
3518    response: Response<proto::JoinChannel>,
3519    session: Session,
3520) -> Result<()> {
3521    let channel_id = ChannelId::from_proto(request.channel_id);
3522    join_channel_internal(channel_id, Box::new(response), session).await
3523}
3524
/// A response handle that can answer a channel join with a `proto::JoinRoomResponse`.
///
/// Implemented for both `Response<proto::JoinChannel>` and `Response<proto::JoinRoom>`, so
/// `join_channel_internal` can serve either request type.
trait JoinChannelInternalResponse {
    /// Send `result` back to the requester, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Type-qualified call: resolves to `Response`'s inherent `send` (inherent items take
        // precedence over this trait's method), so this does not recurse.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same delegation as above, for the `JoinRoom` request.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3538
/// Shared implementation for joining a channel's room.
///
/// Steps, in order:
/// 1. Evict the user from any stale room connection.
/// 2. Join the channel's room in the database.
/// 3. Reply with the room and LiveKit connection info, if a LiveKit client is configured.
/// 4. Announce any membership change and the room update.
/// 5. Refresh the channel for subscribers and update the user's contacts.
///
/// `response` is any handle implementing `JoinChannelInternalResponse`; `join_channel`
/// passes its `Response<proto::JoinChannel>` here.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Scoped so the `db` and `connection_pool` guards taken inside are released before
    // `channel_updated` below re-acquires the connection pool.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard before `leave_room_for_session`, which receives
            // `&session` and presumably takes the guard itself, then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured. A token-generation failure is logged
        // by `trace_err` and also yields `None`, instead of failing the whole join.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests get a guest token and are not allowed to publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // A channel join is expected to yield a channel; its absence is reported as an error.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3632
3633/// Start editing the channel notes
3634async fn join_channel_buffer(
3635    request: proto::JoinChannelBuffer,
3636    response: Response<proto::JoinChannelBuffer>,
3637    session: Session,
3638) -> Result<()> {
3639    let db = session.db().await;
3640    let channel_id = ChannelId::from_proto(request.channel_id);
3641
3642    let open_response = db
3643        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3644        .await?;
3645
3646    let collaborators = open_response.collaborators.clone();
3647    response.send(open_response)?;
3648
3649    let update = UpdateChannelBufferCollaborators {
3650        channel_id: channel_id.to_proto(),
3651        collaborators: collaborators.clone(),
3652    };
3653    channel_buffer_updated(
3654        session.connection_id,
3655        collaborators
3656            .iter()
3657            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3658        &update,
3659        &session.peer,
3660    );
3661
3662    Ok(())
3663}
3664
3665/// Edit the channel notes
3666async fn update_channel_buffer(
3667    request: proto::UpdateChannelBuffer,
3668    session: Session,
3669) -> Result<()> {
3670    let db = session.db().await;
3671    let channel_id = ChannelId::from_proto(request.channel_id);
3672
3673    let (collaborators, epoch, version) = db
3674        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3675        .await?;
3676
3677    channel_buffer_updated(
3678        session.connection_id,
3679        collaborators.clone(),
3680        &proto::UpdateChannelBuffer {
3681            channel_id: channel_id.to_proto(),
3682            operations: request.operations,
3683        },
3684        &session.peer,
3685    );
3686
3687    let pool = &*session.connection_pool().await;
3688
3689    let non_collaborators =
3690        pool.channel_connection_ids(channel_id)
3691            .filter_map(|(connection_id, _)| {
3692                if collaborators.contains(&connection_id) {
3693                    None
3694                } else {
3695                    Some(connection_id)
3696                }
3697            });
3698
3699    broadcast(None, non_collaborators, |peer_id| {
3700        session.peer.send(
3701            peer_id,
3702            proto::UpdateChannels {
3703                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3704                    channel_id: channel_id.to_proto(),
3705                    epoch: epoch as u64,
3706                    version: version.clone(),
3707                }],
3708                ..Default::default()
3709            },
3710        )
3711    });
3712
3713    Ok(())
3714}
3715
3716/// Rejoin the channel notes after a connection blip
3717async fn rejoin_channel_buffers(
3718    request: proto::RejoinChannelBuffers,
3719    response: Response<proto::RejoinChannelBuffers>,
3720    session: Session,
3721) -> Result<()> {
3722    let db = session.db().await;
3723    let buffers = db
3724        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3725        .await?;
3726
3727    for rejoined_buffer in &buffers {
3728        let collaborators_to_notify = rejoined_buffer
3729            .buffer
3730            .collaborators
3731            .iter()
3732            .filter_map(|c| Some(c.peer_id?.into()));
3733        channel_buffer_updated(
3734            session.connection_id,
3735            collaborators_to_notify,
3736            &proto::UpdateChannelBufferCollaborators {
3737                channel_id: rejoined_buffer.buffer.channel_id,
3738                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3739            },
3740            &session.peer,
3741        );
3742    }
3743
3744    response.send(proto::RejoinChannelBuffersResponse {
3745        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3746    })?;
3747
3748    Ok(())
3749}
3750
3751/// Stop editing the channel notes
3752async fn leave_channel_buffer(
3753    request: proto::LeaveChannelBuffer,
3754    response: Response<proto::LeaveChannelBuffer>,
3755    session: Session,
3756) -> Result<()> {
3757    let db = session.db().await;
3758    let channel_id = ChannelId::from_proto(request.channel_id);
3759
3760    let left_buffer = db
3761        .leave_channel_buffer(channel_id, session.connection_id)
3762        .await?;
3763
3764    response.send(Ack {})?;
3765
3766    channel_buffer_updated(
3767        session.connection_id,
3768        left_buffer.connections,
3769        &proto::UpdateChannelBufferCollaborators {
3770            channel_id: channel_id.to_proto(),
3771            collaborators: left_buffer.collaborators,
3772        },
3773        &session.peer,
3774    );
3775
3776    Ok(())
3777}
3778
3779fn channel_buffer_updated<T: EnvelopedMessage>(
3780    sender_id: ConnectionId,
3781    collaborators: impl IntoIterator<Item = ConnectionId>,
3782    message: &T,
3783    peer: &Peer,
3784) {
3785    broadcast(Some(sender_id), collaborators, |peer_id| {
3786        peer.send(peer_id, message.clone())
3787    });
3788}
3789
3790fn send_notifications(
3791    connection_pool: &ConnectionPool,
3792    peer: &Peer,
3793    notifications: db::NotificationBatch,
3794) {
3795    for (user_id, notification) in notifications {
3796        for connection_id in connection_pool.user_connection_ids(user_id) {
3797            if let Err(error) = peer.send(
3798                connection_id,
3799                proto::AddNotification {
3800                    notification: Some(notification.clone()),
3801                },
3802            ) {
3803                tracing::error!(
3804                    "failed to send notification to {:?} {}",
3805                    connection_id,
3806                    error
3807                );
3808            }
3809        }
3810    }
3811}
3812
3813/// Send a message to the channel
3814async fn send_channel_message(
3815    request: proto::SendChannelMessage,
3816    response: Response<proto::SendChannelMessage>,
3817    session: Session,
3818) -> Result<()> {
3819    // Validate the message body.
3820    let body = request.body.trim().to_string();
3821    if body.len() > MAX_MESSAGE_LEN {
3822        return Err(anyhow!("message is too long"))?;
3823    }
3824    if body.is_empty() {
3825        return Err(anyhow!("message can't be blank"))?;
3826    }
3827
3828    // TODO: adjust mentions if body is trimmed
3829
3830    let timestamp = OffsetDateTime::now_utc();
3831    let nonce = request.nonce.context("nonce can't be blank")?;
3832
3833    let channel_id = ChannelId::from_proto(request.channel_id);
3834    let CreatedChannelMessage {
3835        message_id,
3836        participant_connection_ids,
3837        notifications,
3838    } = session
3839        .db()
3840        .await
3841        .create_channel_message(
3842            channel_id,
3843            session.user_id(),
3844            &body,
3845            &request.mentions,
3846            timestamp,
3847            nonce.clone().into(),
3848            request.reply_to_message_id.map(MessageId::from_proto),
3849        )
3850        .await?;
3851
3852    let message = proto::ChannelMessage {
3853        sender_id: session.user_id().to_proto(),
3854        id: message_id.to_proto(),
3855        body,
3856        mentions: request.mentions,
3857        timestamp: timestamp.unix_timestamp() as u64,
3858        nonce: Some(nonce),
3859        reply_to_message_id: request.reply_to_message_id,
3860        edited_at: None,
3861    };
3862    broadcast(
3863        Some(session.connection_id),
3864        participant_connection_ids.clone(),
3865        |connection| {
3866            session.peer.send(
3867                connection,
3868                proto::ChannelMessageSent {
3869                    channel_id: channel_id.to_proto(),
3870                    message: Some(message.clone()),
3871                },
3872            )
3873        },
3874    );
3875    response.send(proto::SendChannelMessageResponse {
3876        message: Some(message),
3877    })?;
3878
3879    let pool = &*session.connection_pool().await;
3880    let non_participants =
3881        pool.channel_connection_ids(channel_id)
3882            .filter_map(|(connection_id, _)| {
3883                if participant_connection_ids.contains(&connection_id) {
3884                    None
3885                } else {
3886                    Some(connection_id)
3887                }
3888            });
3889    broadcast(None, non_participants, |peer_id| {
3890        session.peer.send(
3891            peer_id,
3892            proto::UpdateChannels {
3893                latest_channel_message_ids: vec![proto::ChannelMessageId {
3894                    channel_id: channel_id.to_proto(),
3895                    message_id: message_id.to_proto(),
3896                }],
3897                ..Default::default()
3898            },
3899        )
3900    });
3901    send_notifications(pool, &session.peer, notifications);
3902
3903    Ok(())
3904}
3905
3906/// Delete a channel message
3907async fn remove_channel_message(
3908    request: proto::RemoveChannelMessage,
3909    response: Response<proto::RemoveChannelMessage>,
3910    session: Session,
3911) -> Result<()> {
3912    let channel_id = ChannelId::from_proto(request.channel_id);
3913    let message_id = MessageId::from_proto(request.message_id);
3914    let (connection_ids, existing_notification_ids) = session
3915        .db()
3916        .await
3917        .remove_channel_message(channel_id, message_id, session.user_id())
3918        .await?;
3919
3920    broadcast(
3921        Some(session.connection_id),
3922        connection_ids,
3923        move |connection| {
3924            session.peer.send(connection, request.clone())?;
3925
3926            for notification_id in &existing_notification_ids {
3927                session.peer.send(
3928                    connection,
3929                    proto::DeleteNotification {
3930                        notification_id: (*notification_id).to_proto(),
3931                    },
3932                )?;
3933            }
3934
3935            Ok(())
3936        },
3937    );
3938    response.send(proto::Ack {})?;
3939    Ok(())
3940}
3941
3942async fn update_channel_message(
3943    request: proto::UpdateChannelMessage,
3944    response: Response<proto::UpdateChannelMessage>,
3945    session: Session,
3946) -> Result<()> {
3947    let channel_id = ChannelId::from_proto(request.channel_id);
3948    let message_id = MessageId::from_proto(request.message_id);
3949    let updated_at = OffsetDateTime::now_utc();
3950    let UpdatedChannelMessage {
3951        message_id,
3952        participant_connection_ids,
3953        notifications,
3954        reply_to_message_id,
3955        timestamp,
3956        deleted_mention_notification_ids,
3957        updated_mention_notifications,
3958    } = session
3959        .db()
3960        .await
3961        .update_channel_message(
3962            channel_id,
3963            message_id,
3964            session.user_id(),
3965            request.body.as_str(),
3966            &request.mentions,
3967            updated_at,
3968        )
3969        .await?;
3970
3971    let nonce = request.nonce.clone().context("nonce can't be blank")?;
3972
3973    let message = proto::ChannelMessage {
3974        sender_id: session.user_id().to_proto(),
3975        id: message_id.to_proto(),
3976        body: request.body.clone(),
3977        mentions: request.mentions.clone(),
3978        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3979        nonce: Some(nonce),
3980        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3981        edited_at: Some(updated_at.unix_timestamp() as u64),
3982    };
3983
3984    response.send(proto::Ack {})?;
3985
3986    let pool = &*session.connection_pool().await;
3987    broadcast(
3988        Some(session.connection_id),
3989        participant_connection_ids,
3990        |connection| {
3991            session.peer.send(
3992                connection,
3993                proto::ChannelMessageUpdate {
3994                    channel_id: channel_id.to_proto(),
3995                    message: Some(message.clone()),
3996                },
3997            )?;
3998
3999            for notification_id in &deleted_mention_notification_ids {
4000                session.peer.send(
4001                    connection,
4002                    proto::DeleteNotification {
4003                        notification_id: (*notification_id).to_proto(),
4004                    },
4005                )?;
4006            }
4007
4008            for notification in &updated_mention_notifications {
4009                session.peer.send(
4010                    connection,
4011                    proto::UpdateNotification {
4012                        notification: Some(notification.clone()),
4013                    },
4014                )?;
4015            }
4016
4017            Ok(())
4018        },
4019    );
4020
4021    send_notifications(pool, &session.peer, notifications);
4022
4023    Ok(())
4024}
4025
4026/// Mark a channel message as read
4027async fn acknowledge_channel_message(
4028    request: proto::AckChannelMessage,
4029    session: Session,
4030) -> Result<()> {
4031    let channel_id = ChannelId::from_proto(request.channel_id);
4032    let message_id = MessageId::from_proto(request.message_id);
4033    let notifications = session
4034        .db()
4035        .await
4036        .observe_channel_message(channel_id, session.user_id(), message_id)
4037        .await?;
4038    send_notifications(
4039        &*session.connection_pool().await,
4040        &session.peer,
4041        notifications,
4042    );
4043    Ok(())
4044}
4045
4046/// Mark a buffer version as synced
4047async fn acknowledge_buffer_version(
4048    request: proto::AckBufferOperation,
4049    session: Session,
4050) -> Result<()> {
4051    let buffer_id = BufferId::from_proto(request.buffer_id);
4052    session
4053        .db()
4054        .await
4055        .observe_buffer_version(
4056            buffer_id,
4057            session.user_id(),
4058            request.epoch as i32,
4059            &request.version,
4060        )
4061        .await?;
4062    Ok(())
4063}
4064
4065/// Get a Supermaven API key for the user
4066async fn get_supermaven_api_key(
4067    _request: proto::GetSupermavenApiKey,
4068    response: Response<proto::GetSupermavenApiKey>,
4069    session: Session,
4070) -> Result<()> {
4071    let user_id: String = session.user_id().to_string();
4072    if !session.is_staff() {
4073        return Err(anyhow!("supermaven not enabled for this account"))?;
4074    }
4075
4076    let email = session.email().context("user must have an email")?;
4077
4078    let supermaven_admin_api = session
4079        .supermaven_client
4080        .as_ref()
4081        .context("supermaven not configured")?;
4082
4083    let result = supermaven_admin_api
4084        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4085        .await?;
4086
4087    response.send(proto::GetSupermavenApiKeyResponse {
4088        api_key: result.api_key,
4089    })?;
4090
4091    Ok(())
4092}
4093
4094/// Start receiving chat updates for a channel
4095async fn join_channel_chat(
4096    request: proto::JoinChannelChat,
4097    response: Response<proto::JoinChannelChat>,
4098    session: Session,
4099) -> Result<()> {
4100    let channel_id = ChannelId::from_proto(request.channel_id);
4101
4102    let db = session.db().await;
4103    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4104        .await?;
4105    let messages = db
4106        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4107        .await?;
4108    response.send(proto::JoinChannelChatResponse {
4109        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4110        messages,
4111    })?;
4112    Ok(())
4113}
4114
4115/// Stop receiving chat updates for a channel
4116async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
4117    let channel_id = ChannelId::from_proto(request.channel_id);
4118    session
4119        .db()
4120        .await
4121        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4122        .await?;
4123    Ok(())
4124}
4125
4126/// Retrieve the chat history for a channel
4127async fn get_channel_messages(
4128    request: proto::GetChannelMessages,
4129    response: Response<proto::GetChannelMessages>,
4130    session: Session,
4131) -> Result<()> {
4132    let channel_id = ChannelId::from_proto(request.channel_id);
4133    let messages = session
4134        .db()
4135        .await
4136        .get_channel_messages(
4137            channel_id,
4138            session.user_id(),
4139            MESSAGE_COUNT_PER_PAGE,
4140            Some(MessageId::from_proto(request.before_message_id)),
4141        )
4142        .await?;
4143    response.send(proto::GetChannelMessagesResponse {
4144        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4145        messages,
4146    })?;
4147    Ok(())
4148}
4149
4150/// Retrieve specific chat messages
4151async fn get_channel_messages_by_id(
4152    request: proto::GetChannelMessagesById,
4153    response: Response<proto::GetChannelMessagesById>,
4154    session: Session,
4155) -> Result<()> {
4156    let message_ids = request
4157        .message_ids
4158        .iter()
4159        .map(|id| MessageId::from_proto(*id))
4160        .collect::<Vec<_>>();
4161    let messages = session
4162        .db()
4163        .await
4164        .get_channel_messages_by_id(session.user_id(), &message_ids)
4165        .await?;
4166    response.send(proto::GetChannelMessagesResponse {
4167        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4168        messages,
4169    })?;
4170    Ok(())
4171}
4172
4173/// Retrieve the current users notifications
4174async fn get_notifications(
4175    request: proto::GetNotifications,
4176    response: Response<proto::GetNotifications>,
4177    session: Session,
4178) -> Result<()> {
4179    let notifications = session
4180        .db()
4181        .await
4182        .get_notifications(
4183            session.user_id(),
4184            NOTIFICATION_COUNT_PER_PAGE,
4185            request.before_id.map(db::NotificationId::from_proto),
4186        )
4187        .await?;
4188    response.send(proto::GetNotificationsResponse {
4189        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4190        notifications,
4191    })?;
4192    Ok(())
4193}
4194
4195/// Mark notifications as read
4196async fn mark_notification_as_read(
4197    request: proto::MarkNotificationRead,
4198    response: Response<proto::MarkNotificationRead>,
4199    session: Session,
4200) -> Result<()> {
4201    let database = &session.db().await;
4202    let notifications = database
4203        .mark_notification_as_read_by_id(
4204            session.user_id(),
4205            NotificationId::from_proto(request.notification_id),
4206        )
4207        .await?;
4208    send_notifications(
4209        &*session.connection_pool().await,
4210        &session.peer,
4211        notifications,
4212    );
4213    response.send(proto::Ack {})?;
4214    Ok(())
4215}
4216
4217/// Get the current users information
4218async fn get_private_user_info(
4219    _request: proto::GetPrivateUserInfo,
4220    response: Response<proto::GetPrivateUserInfo>,
4221    session: Session,
4222) -> Result<()> {
4223    let db = session.db().await;
4224
4225    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4226    let user = db
4227        .get_user_by_id(session.user_id())
4228        .await?
4229        .context("user not found")?;
4230    let flags = db.get_user_flags(session.user_id()).await?;
4231
4232    response.send(proto::GetPrivateUserInfoResponse {
4233        metrics_id,
4234        staff: user.admin,
4235        flags,
4236        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4237    })?;
4238    Ok(())
4239}
4240
4241/// Accept the terms of service (tos) on behalf of the current user
4242async fn accept_terms_of_service(
4243    _request: proto::AcceptTermsOfService,
4244    response: Response<proto::AcceptTermsOfService>,
4245    session: Session,
4246) -> Result<()> {
4247    let db = session.db().await;
4248
4249    let accepted_tos_at = Utc::now();
4250    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4251        .await?;
4252
4253    response.send(proto::AcceptTermsOfServiceResponse {
4254        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4255    })?;
4256
4257    // When the user accepts the terms of service, we want to refresh their LLM
4258    // token to grant access.
4259    session
4260        .peer
4261        .send(session.connection_id, proto::RefreshLlmToken {})?;
4262
4263    Ok(())
4264}
4265
4266async fn get_llm_api_token(
4267    _request: proto::GetLlmToken,
4268    response: Response<proto::GetLlmToken>,
4269    session: Session,
4270) -> Result<()> {
4271    let db = session.db().await;
4272
4273    let flags = db.get_user_flags(session.user_id()).await?;
4274
4275    let user_id = session.user_id();
4276    let user = db
4277        .get_user_by_id(user_id)
4278        .await?
4279        .with_context(|| format!("user {user_id} not found"))?;
4280
4281    if user.accepted_tos_at.is_none() {
4282        Err(anyhow!("terms of service not accepted"))?
4283    }
4284
4285    let stripe_client = session
4286        .app_state
4287        .stripe_client
4288        .as_ref()
4289        .context("failed to retrieve Stripe client")?;
4290
4291    let stripe_billing = session
4292        .app_state
4293        .stripe_billing
4294        .as_ref()
4295        .context("failed to retrieve Stripe billing object")?;
4296
4297    let billing_customer = if let Some(billing_customer) =
4298        db.get_billing_customer_by_user_id(user.id).await?
4299    {
4300        billing_customer
4301    } else {
4302        let customer_id = stripe_billing
4303            .find_or_create_customer_by_email(user.email_address.as_deref())
4304            .await?;
4305
4306        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
4307            .await?
4308            .context("billing customer not found")?
4309    };
4310
4311    let billing_subscription =
4312        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
4313            billing_subscription
4314        } else {
4315            let stripe_customer_id =
4316                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
4317
4318            let stripe_subscription = stripe_billing
4319                .subscribe_to_zed_free(stripe_customer_id)
4320                .await?;
4321
4322            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
4323                billing_customer_id: billing_customer.id,
4324                kind: Some(SubscriptionKind::ZedFree),
4325                stripe_subscription_id: stripe_subscription.id.to_string(),
4326                stripe_subscription_status: stripe_subscription.status.into(),
4327                stripe_cancellation_reason: None,
4328                stripe_current_period_start: Some(stripe_subscription.current_period_start),
4329                stripe_current_period_end: Some(stripe_subscription.current_period_end),
4330            })
4331            .await?
4332        };
4333
4334    let billing_preferences = db.get_billing_preferences(user.id).await?;
4335
4336    let token = LlmTokenClaims::create(
4337        &user,
4338        session.is_staff(),
4339        billing_customer,
4340        billing_preferences,
4341        &flags,
4342        billing_subscription,
4343        session.system_id.clone(),
4344        &session.app_state.config,
4345    )?;
4346    response.send(proto::GetLlmTokenResponse { token })?;
4347    Ok(())
4348}
4349
4350fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4351    let message = match message {
4352        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4353        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4354        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4355        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4356        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4357            code: frame.code.into(),
4358            reason: frame.reason.as_str().to_owned().into(),
4359        })),
4360        // We should never receive a frame while reading the message, according
4361        // to the `tungstenite` maintainers:
4362        //
4363        // > It cannot occur when you read messages from the WebSocket, but it
4364        // > can be used when you want to send the raw frames (e.g. you want to
4365        // > send the frames to the WebSocket without composing the full message first).
4366        // >
4367        // > — https://github.com/snapview/tungstenite-rs/issues/268
4368        TungsteniteMessage::Frame(_) => {
4369            bail!("received an unexpected frame while reading the message")
4370        }
4371    };
4372
4373    Ok(message)
4374}
4375
4376fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4377    match message {
4378        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4379        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4380        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4381        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4382        AxumMessage::Close(frame) => {
4383            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4384                code: frame.code.into(),
4385                reason: frame.reason.as_ref().into(),
4386            }))
4387        }
4388    }
4389}
4390
4391fn notify_membership_updated(
4392    connection_pool: &mut ConnectionPool,
4393    result: MembershipUpdated,
4394    user_id: UserId,
4395    peer: &Peer,
4396) {
4397    for membership in &result.new_channels.channel_memberships {
4398        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4399    }
4400    for channel_id in &result.removed_channels {
4401        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4402    }
4403
4404    let user_channels_update = proto::UpdateUserChannels {
4405        channel_memberships: result
4406            .new_channels
4407            .channel_memberships
4408            .iter()
4409            .map(|cm| proto::ChannelMembership {
4410                channel_id: cm.channel_id.to_proto(),
4411                role: cm.role.into(),
4412            })
4413            .collect(),
4414        ..Default::default()
4415    };
4416
4417    let mut update = build_channels_update(result.new_channels);
4418    update.delete_channels = result
4419        .removed_channels
4420        .into_iter()
4421        .map(|id| id.to_proto())
4422        .collect();
4423    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4424
4425    for connection_id in connection_pool.user_connection_ids(user_id) {
4426        peer.send(connection_id, user_channels_update.clone())
4427            .trace_err();
4428        peer.send(connection_id, update.clone()).trace_err();
4429    }
4430}
4431
4432fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4433    proto::UpdateUserChannels {
4434        channel_memberships: channels
4435            .channel_memberships
4436            .iter()
4437            .map(|m| proto::ChannelMembership {
4438                channel_id: m.channel_id.to_proto(),
4439                role: m.role.into(),
4440            })
4441            .collect(),
4442        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4443        observed_channel_message_id: channels.observed_channel_messages.clone(),
4444    }
4445}
4446
4447fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4448    let mut update = proto::UpdateChannels::default();
4449
4450    for channel in channels.channels {
4451        update.channels.push(channel.to_proto());
4452    }
4453
4454    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4455    update.latest_channel_message_ids = channels.latest_channel_messages;
4456
4457    for (channel_id, participants) in channels.channel_participants {
4458        update
4459            .channel_participants
4460            .push(proto::ChannelParticipants {
4461                channel_id: channel_id.to_proto(),
4462                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4463            });
4464    }
4465
4466    for channel in channels.invited_channels {
4467        update.channel_invitations.push(channel.to_proto());
4468    }
4469
4470    update
4471}
4472
4473fn build_initial_contacts_update(
4474    contacts: Vec<db::Contact>,
4475    pool: &ConnectionPool,
4476) -> proto::UpdateContacts {
4477    let mut update = proto::UpdateContacts::default();
4478
4479    for contact in contacts {
4480        match contact {
4481            db::Contact::Accepted { user_id, busy } => {
4482                update.contacts.push(contact_for_user(user_id, busy, pool));
4483            }
4484            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4485            db::Contact::Incoming { user_id } => {
4486                update
4487                    .incoming_requests
4488                    .push(proto::IncomingContactRequest {
4489                        requester_id: user_id.to_proto(),
4490                    })
4491            }
4492        }
4493    }
4494
4495    update
4496}
4497
4498fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4499    proto::Contact {
4500        user_id: user_id.to_proto(),
4501        online: pool.is_user_online(user_id),
4502        busy,
4503    }
4504}
4505
4506fn room_updated(room: &proto::Room, peer: &Peer) {
4507    broadcast(
4508        None,
4509        room.participants
4510            .iter()
4511            .filter_map(|participant| Some(participant.peer_id?.into())),
4512        |peer_id| {
4513            peer.send(
4514                peer_id,
4515                proto::RoomUpdated {
4516                    room: Some(room.clone()),
4517                },
4518            )
4519        },
4520    );
4521}
4522
4523fn channel_updated(
4524    channel: &db::channel::Model,
4525    room: &proto::Room,
4526    peer: &Peer,
4527    pool: &ConnectionPool,
4528) {
4529    let participants = room
4530        .participants
4531        .iter()
4532        .map(|p| p.user_id)
4533        .collect::<Vec<_>>();
4534
4535    broadcast(
4536        None,
4537        pool.channel_connection_ids(channel.root_id())
4538            .filter_map(|(channel_id, role)| {
4539                role.can_see_channel(channel.visibility)
4540                    .then_some(channel_id)
4541            }),
4542        |peer_id| {
4543            peer.send(
4544                peer_id,
4545                proto::UpdateChannels {
4546                    channel_participants: vec![proto::ChannelParticipants {
4547                        channel_id: channel.id.to_proto(),
4548                        participant_user_ids: participants.clone(),
4549                    }],
4550                    ..Default::default()
4551                },
4552            )
4553        },
4554    );
4555}
4556
4557async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4558    let db = session.db().await;
4559
4560    let contacts = db.get_contacts(user_id).await?;
4561    let busy = db.is_user_busy(user_id).await?;
4562
4563    let pool = session.connection_pool().await;
4564    let updated_contact = contact_for_user(user_id, busy, &pool);
4565    for contact in contacts {
4566        if let db::Contact::Accepted {
4567            user_id: contact_user_id,
4568            ..
4569        } = contact
4570        {
4571            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4572                session
4573                    .peer
4574                    .send(
4575                        contact_conn_id,
4576                        proto::UpdateContacts {
4577                            contacts: vec![updated_contact.clone()],
4578                            remove_contacts: Default::default(),
4579                            incoming_requests: Default::default(),
4580                            remove_incoming_requests: Default::default(),
4581                            outgoing_requests: Default::default(),
4582                            remove_outgoing_requests: Default::default(),
4583                        },
4584                    )
4585                    .trace_err();
4586            }
4587        }
4588    }
4589    Ok(())
4590}
4591
4592async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4593    let mut contacts_to_update = HashSet::default();
4594
4595    let room_id;
4596    let canceled_calls_to_user_ids;
4597    let livekit_room;
4598    let delete_livekit_room;
4599    let room;
4600    let channel;
4601
4602    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4603        contacts_to_update.insert(session.user_id());
4604
4605        for project in left_room.left_projects.values() {
4606            project_left(project, session);
4607        }
4608
4609        room_id = RoomId::from_proto(left_room.room.id);
4610        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4611        livekit_room = mem::take(&mut left_room.room.livekit_room);
4612        delete_livekit_room = left_room.deleted;
4613        room = mem::take(&mut left_room.room);
4614        channel = mem::take(&mut left_room.channel);
4615
4616        room_updated(&room, &session.peer);
4617    } else {
4618        return Ok(());
4619    }
4620
4621    if let Some(channel) = channel {
4622        channel_updated(
4623            &channel,
4624            &room,
4625            &session.peer,
4626            &*session.connection_pool().await,
4627        );
4628    }
4629
4630    {
4631        let pool = session.connection_pool().await;
4632        for canceled_user_id in canceled_calls_to_user_ids {
4633            for connection_id in pool.user_connection_ids(canceled_user_id) {
4634                session
4635                    .peer
4636                    .send(
4637                        connection_id,
4638                        proto::CallCanceled {
4639                            room_id: room_id.to_proto(),
4640                        },
4641                    )
4642                    .trace_err();
4643            }
4644            contacts_to_update.insert(canceled_user_id);
4645        }
4646    }
4647
4648    for contact_user_id in contacts_to_update {
4649        update_user_contacts(contact_user_id, session).await?;
4650    }
4651
4652    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4653        live_kit
4654            .remove_participant(livekit_room.clone(), session.user_id().to_string())
4655            .await
4656            .trace_err();
4657
4658        if delete_livekit_room {
4659            live_kit.delete_room(livekit_room).await.trace_err();
4660        }
4661    }
4662
4663    Ok(())
4664}
4665
4666async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4667    let left_channel_buffers = session
4668        .db()
4669        .await
4670        .leave_channel_buffers(session.connection_id)
4671        .await?;
4672
4673    for left_buffer in left_channel_buffers {
4674        channel_buffer_updated(
4675            session.connection_id,
4676            left_buffer.connections,
4677            &proto::UpdateChannelBufferCollaborators {
4678                channel_id: left_buffer.channel_id.to_proto(),
4679                collaborators: left_buffer.collaborators,
4680            },
4681            &session.peer,
4682        );
4683    }
4684
4685    Ok(())
4686}
4687
4688fn project_left(project: &db::LeftProject, session: &Session) {
4689    for connection_id in &project.connection_ids {
4690        if project.should_unshare {
4691            session
4692                .peer
4693                .send(
4694                    *connection_id,
4695                    proto::UnshareProject {
4696                        project_id: project.id.to_proto(),
4697                    },
4698                )
4699                .trace_err();
4700        } else {
4701            session
4702                .peer
4703                .send(
4704                    *connection_id,
4705                    proto::RemoveProjectCollaborator {
4706                        project_id: project.id.to_proto(),
4707                        peer_id: Some(session.connection_id.into()),
4708                    },
4709                )
4710                .trace_err();
4711        }
4712    }
4713}
4714
4715pub trait ResultExt {
4716    type Ok;
4717
4718    fn trace_err(self) -> Option<Self::Ok>;
4719}
4720
4721impl<T, E> ResultExt for Result<T, E>
4722where
4723    E: std::fmt::Debug,
4724{
4725    type Ok = T;
4726
4727    #[track_caller]
4728    fn trace_err(self) -> Option<T> {
4729        match self {
4730            Ok(value) => Some(value),
4731            Err(error) => {
4732                tracing::error!("{:?}", error);
4733                None
4734            }
4735        }
4736    }
4737}