rpc.rs

   1mod connection_pool;
   2
   3use crate::api::billing::find_or_create_billing_customer;
   4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   5use crate::db::billing_subscription::SubscriptionKind;
   6use crate::llm::db::LlmDatabase;
   7use crate::llm::{
   8    AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
   9    MIN_ACCOUNT_AGE_FOR_LLM_USE,
  10};
  11use crate::stripe_client::StripeCustomerId;
  12use crate::{
  13    AppState, Error, Result, auth,
  14    db::{
  15        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
  16        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  17        NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
  18        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  19    },
  20    executor::Executor,
  21};
  22use anyhow::{Context as _, anyhow, bail};
  23use async_tungstenite::tungstenite::{
  24    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  25};
  26use axum::headers::UserAgent;
  27use axum::{
  28    Extension, Router, TypedHeader,
  29    body::Body,
  30    extract::{
  31        ConnectInfo, WebSocketUpgrade,
  32        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  33    },
  34    headers::{Header, HeaderName},
  35    http::StatusCode,
  36    middleware,
  37    response::IntoResponse,
  38    routing::get,
  39};
  40use chrono::Utc;
  41use collections::{HashMap, HashSet};
  42pub use connection_pool::{ConnectionPool, ZedVersion};
  43use core::fmt::{self, Debug, Formatter};
  44use reqwest_client::ReqwestClient;
  45use rpc::proto::{MultiLspQuery, split_repository_update};
  46use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  47
  48use futures::{
  49    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  50    stream::FuturesUnordered,
  51};
  52use prometheus::{IntGauge, register_int_gauge};
  53use rpc::{
  54    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  55    proto::{
  56        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  57        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  58    },
  59};
  60use semantic_version::SemanticVersion;
  61use serde::{Serialize, Serializer};
  62use std::{
  63    any::TypeId,
  64    future::Future,
  65    marker::PhantomData,
  66    mem,
  67    net::SocketAddr,
  68    ops::{Deref, DerefMut},
  69    rc::Rc,
  70    sync::{
  71        Arc, OnceLock,
  72        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  73    },
  74    time::{Duration, Instant},
  75};
  76use time::OffsetDateTime;
  77use tokio::sync::{Semaphore, watch};
  78use tower::ServiceBuilder;
  79use tracing::{
  80    Instrument,
  81    field::{self},
  82    info_span, instrument,
  83};
  84
  85pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  86
  87// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  88pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  89
  90const MESSAGE_COUNT_PER_PAGE: usize = 100;
  91const MAX_MESSAGE_LEN: usize = 1024;
  92const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  93const MAX_CONCURRENT_CONNECTIONS: usize = 512;
  94
  95static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
  96
  97type MessageHandler =
  98    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  99
 100pub struct ConnectionGuard;
 101
 102impl ConnectionGuard {
 103    pub fn try_acquire() -> Result<Self, ()> {
 104        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
 105        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 106            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 107            tracing::error!(
 108                "too many concurrent connections: {}",
 109                current_connections + 1
 110            );
 111            return Err(());
 112        }
 113        Ok(ConnectionGuard)
 114    }
 115}
 116
 117impl Drop for ConnectionGuard {
 118    fn drop(&mut self) {
 119        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 120    }
 121}
 122
 123struct Response<R> {
 124    peer: Arc<Peer>,
 125    receipt: Receipt<R>,
 126    responded: Arc<AtomicBool>,
 127}
 128
 129impl<R: RequestMessage> Response<R> {
 130    fn send(self, payload: R::Response) -> Result<()> {
 131        self.responded.store(true, SeqCst);
 132        self.peer.respond(self.receipt, payload)?;
 133        Ok(())
 134    }
 135}
 136
 137#[derive(Clone, Debug)]
 138pub enum Principal {
 139    User(User),
 140    Impersonated { user: User, admin: User },
 141}
 142
 143impl Principal {
 144    fn user(&self) -> &User {
 145        match self {
 146            Principal::User(user) => user,
 147            Principal::Impersonated { user, .. } => user,
 148        }
 149    }
 150
 151    fn update_span(&self, span: &tracing::Span) {
 152        match &self {
 153            Principal::User(user) => {
 154                span.record("user_id", user.id.0);
 155                span.record("login", &user.github_login);
 156            }
 157            Principal::Impersonated { user, admin } => {
 158                span.record("user_id", user.id.0);
 159                span.record("login", &user.github_login);
 160                span.record("impersonator", &admin.github_login);
 161            }
 162        }
 163    }
 164}
 165
/// Per-connection state handed to every message handler (see `Server::add_handler`).
#[derive(Clone)]
struct Session {
    /// The identity this connection acts as.
    principal: Principal,
    /// Id of the peer connection this session belongs to.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Connection pool behind a synchronous mutex; lock it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    // The `allow(unused)` below indicates the field is currently never read.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    // Held but not read (underscore prefix). NOTE(review): presumably kept so the executor lives
    // as long as the session — confirm.
    _executor: Executor,
}
 181
 182impl Session {
 183    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 184        #[cfg(test)]
 185        tokio::task::yield_now().await;
 186        let guard = self.db.lock().await;
 187        #[cfg(test)]
 188        tokio::task::yield_now().await;
 189        guard
 190    }
 191
 192    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 193        #[cfg(test)]
 194        tokio::task::yield_now().await;
 195        let guard = self.connection_pool.lock();
 196        ConnectionPoolGuard {
 197            guard,
 198            _not_send: PhantomData,
 199        }
 200    }
 201
 202    fn is_staff(&self) -> bool {
 203        match &self.principal {
 204            Principal::User(user) => user.admin,
 205            Principal::Impersonated { .. } => true,
 206        }
 207    }
 208
 209    fn user_id(&self) -> UserId {
 210        match &self.principal {
 211            Principal::User(user) => user.id,
 212            Principal::Impersonated { user, .. } => user.id,
 213        }
 214    }
 215
 216    pub fn email(&self) -> Option<String> {
 217        match &self.principal {
 218            Principal::User(user) => user.email_address.clone(),
 219            Principal::Impersonated { user, .. } => user.email_address.clone(),
 220        }
 221    }
 222}
 223
 224impl Debug for Session {
 225    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 226        let mut result = f.debug_struct("Session");
 227        match &self.principal {
 228            Principal::User(user) => {
 229                result.field("user", &user.github_login);
 230            }
 231            Principal::Impersonated { user, admin } => {
 232                result.field("user", &user.github_login);
 233                result.field("impersonator", &admin.github_login);
 234            }
 235        }
 236        result.field("connection_id", &self.connection_id).finish()
 237    }
 238}
 239
 240struct DbHandle(Arc<Database>);
 241
 242impl Deref for DbHandle {
 243    type Target = Database;
 244
 245    fn deref(&self) -> &Self::Target {
 246        self.0.as_ref()
 247    }
 248}
 249
/// The RPC server: owns the peer, the connection pool, and the table of message handlers.
pub struct Server {
    /// Current server id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Pool of connections; cleared by `teardown`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type each one handles.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown signal: set to `true` by `teardown` and back to `false` by the test-only `reset`.
    teardown: watch::Sender<bool>,
}
 258
/// Lock guard for the shared `ConnectionPool`, as returned by `Session::connection_pool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is `!Send`, so this marker makes the guard `!Send`. Handler futures must be `Send`
    // (see the `Fut: Send` bound on `Server::add_handler`), so holding this synchronous lock
    // across an `.await` inside a handler becomes a compile error.
    _not_send: PhantomData<Rc<()>>,
}
 263
/// Serializable view of the server's peer and connection pool.
///
/// The snapshot owns a `ConnectionPoolGuard`, so the connection pool stays locked for as long as
/// the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the guard's `Deref` target, i.e. the `ConnectionPool` itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 270
 271pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 272where
 273    S: Serializer,
 274    T: Deref<Target = U>,
 275    U: Serialize,
 276{
 277    Serialize::serialize(value.deref(), serializer)
 278}
 279
 280impl Server {
 281    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 282        let mut server = Self {
 283            id: parking_lot::Mutex::new(id),
 284            peer: Peer::new(id.0 as u32),
 285            app_state: app_state.clone(),
 286            connection_pool: Default::default(),
 287            handlers: Default::default(),
 288            teardown: watch::channel(false).0,
 289        };
 290
 291        server
 292            .add_request_handler(ping)
 293            .add_request_handler(create_room)
 294            .add_request_handler(join_room)
 295            .add_request_handler(rejoin_room)
 296            .add_request_handler(leave_room)
 297            .add_request_handler(set_room_participant_role)
 298            .add_request_handler(call)
 299            .add_request_handler(cancel_call)
 300            .add_message_handler(decline_call)
 301            .add_request_handler(update_participant_location)
 302            .add_request_handler(share_project)
 303            .add_message_handler(unshare_project)
 304            .add_request_handler(join_project)
 305            .add_message_handler(leave_project)
 306            .add_request_handler(update_project)
 307            .add_request_handler(update_worktree)
 308            .add_request_handler(update_repository)
 309            .add_request_handler(remove_repository)
 310            .add_message_handler(start_language_server)
 311            .add_message_handler(update_language_server)
 312            .add_message_handler(update_diagnostic_summary)
 313            .add_message_handler(update_worktree_settings)
 314            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 315            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 316            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 317            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 318            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 319            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 320            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 321            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 322            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 323            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 324            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 325            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 326            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 327            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 328            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 329            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 330            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 331            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 332            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 333            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 334            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 335            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 336            .add_request_handler(
 337                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 338            )
 339            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 340            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 341            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 342            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 343            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
 344            .add_request_handler(
 345                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 346            )
 347            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 348            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 349            .add_request_handler(
 350                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 351            )
 352            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 353            .add_request_handler(
 354                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 355            )
 356            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 357            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 358            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 359            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 360            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 361            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 362            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 363            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 364            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 365            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 366            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 367            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 368            .add_request_handler(
 369                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 370            )
 371            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 372            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 373            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 374            .add_request_handler(multi_lsp_query)
 375            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 376            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 377            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 378            .add_message_handler(create_buffer_for_peer)
 379            .add_request_handler(update_buffer)
 380            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 381            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 382            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 383            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 384            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 385            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 386            .add_message_handler(
 387                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 388            )
 389            .add_request_handler(get_users)
 390            .add_request_handler(fuzzy_search_users)
 391            .add_request_handler(request_contact)
 392            .add_request_handler(remove_contact)
 393            .add_request_handler(respond_to_contact_request)
 394            .add_message_handler(subscribe_to_channels)
 395            .add_request_handler(create_channel)
 396            .add_request_handler(delete_channel)
 397            .add_request_handler(invite_channel_member)
 398            .add_request_handler(remove_channel_member)
 399            .add_request_handler(set_channel_member_role)
 400            .add_request_handler(set_channel_visibility)
 401            .add_request_handler(rename_channel)
 402            .add_request_handler(join_channel_buffer)
 403            .add_request_handler(leave_channel_buffer)
 404            .add_message_handler(update_channel_buffer)
 405            .add_request_handler(rejoin_channel_buffers)
 406            .add_request_handler(get_channel_members)
 407            .add_request_handler(respond_to_channel_invite)
 408            .add_request_handler(join_channel)
 409            .add_request_handler(join_channel_chat)
 410            .add_message_handler(leave_channel_chat)
 411            .add_request_handler(send_channel_message)
 412            .add_request_handler(remove_channel_message)
 413            .add_request_handler(update_channel_message)
 414            .add_request_handler(get_channel_messages)
 415            .add_request_handler(get_channel_messages_by_id)
 416            .add_request_handler(get_notifications)
 417            .add_request_handler(mark_notification_as_read)
 418            .add_request_handler(move_channel)
 419            .add_request_handler(reorder_channel)
 420            .add_request_handler(follow)
 421            .add_message_handler(unfollow)
 422            .add_message_handler(update_followers)
 423            .add_request_handler(get_private_user_info)
 424            .add_request_handler(get_llm_api_token)
 425            .add_request_handler(accept_terms_of_service)
 426            .add_message_handler(acknowledge_channel_message)
 427            .add_message_handler(acknowledge_buffer_version)
 428            .add_request_handler(get_supermaven_api_key)
 429            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 430            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 431            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 432            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 433            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 434            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 435            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 436            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 437            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 438            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 439            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 440            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 441            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 442            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 443            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 444            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 445            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 446            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 447            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 448            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 449            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 450            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 451            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 452            .add_message_handler(update_context);
 453
 454        Arc::new(server)
 455    }
 456
    /// Schedules cleanup of resources left behind by previous server instances.
    ///
    /// Spawns a detached task on `app_state.executor` and returns `Ok(())` immediately. After
    /// waiting `CLEANUP_TIMEOUT`, the task uses this server's id and the configured
    /// `zed_environment` to:
    /// 1. delete stale channel-chat participants;
    /// 2. for each stale channel buffer, clear stale collaborators and send the refreshed
    ///    collaborator list to the remaining connections;
    /// 3. for each stale room, clear stale participants, then:
    ///    - broadcast the room (and its channel, if any),
    ///    - send `CallCanceled` to users whose calls were canceled,
    ///    - push an updated contact entry to the accepted contacts of every affected user,
    ///    - delete the LiveKit room if no participants remain;
    /// 4. delete stale channel-chat participants again, then old worktree entries, then stale
    ///    servers.
    ///
    /// Every fallible call in the task goes through `trace_err` (which yields an `Option`, as the
    /// `if let Some` uses show) instead of returning early, so one failure does not skip the
    /// remaining steps.
    pub async fn start(&self) -> Result<()> {
        // Clone everything the detached `async move` task needs to own.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and tell the remaining
                    // connections about the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults apply when clearing the room fails: nothing to notify and
                        // the LiveKit room is left alone.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the synchronous pool lock is released before the `.await`s
                        // in the contact loop below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Recompute each affected user's busy state and contact entry, then push
                        // it to every connection of that user's accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only tear down the LiveKit room once nobody is left in it.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // NOTE(review): this repeats the call made at the start of the task; confirm
                // whether the second pass is intentional.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 627
 628    pub fn teardown(&self) {
 629        self.peer.teardown();
 630        self.connection_pool.lock().reset();
 631        let _ = self.teardown.send(true);
 632    }
 633
 634    #[cfg(test)]
 635    pub fn reset(&self, id: ServerId) {
 636        self.teardown();
 637        *self.id.lock() = id;
 638        self.peer.reset(id.0 as u32);
 639        let _ = self.teardown.send(false);
 640    }
 641
 642    #[cfg(test)]
 643    pub fn id(&self) -> ServerId {
 644        *self.id.lock()
 645    }
 646
 647    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 648    where
 649        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 650        Fut: 'static + Send + Future<Output = Result<()>>,
 651        M: EnvelopedMessage,
 652    {
 653        let prev_handler = self.handlers.insert(
 654            TypeId::of::<M>(),
 655            Box::new(move |envelope, session| {
 656                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 657                let received_at = envelope.received_at;
 658                tracing::info!("message received");
 659                let start_time = Instant::now();
 660                let future = (handler)(*envelope, session);
 661                async move {
 662                    let result = future.await;
 663                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 664                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 665                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 666
 667                    match result {
 668                        Err(error) => {
 669                            tracing::error!(
 670                                ?error,
 671                                total_duration_ms,
 672                                processing_duration_ms,
 673                                queue_duration_ms,
 674                                "error handling message"
 675                            )
 676                        }
 677                        Ok(()) => tracing::info!(
 678                            total_duration_ms,
 679                            processing_duration_ms,
 680                            queue_duration_ms,
 681                            "finished handling message"
 682                        ),
 683                    }
 684                }
 685                .boxed()
 686            }),
 687        );
 688        if prev_handler.is_some() {
 689            panic!("registered a handler for the same message twice");
 690        }
 691        self
 692    }
 693
 694    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 695    where
 696        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 697        Fut: 'static + Send + Future<Output = Result<()>>,
 698        M: EnvelopedMessage,
 699    {
 700        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 701        self
 702    }
 703
 704    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 705    where
 706        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 707        Fut: Send + Future<Output = Result<()>>,
 708        M: RequestMessage,
 709    {
 710        let handler = Arc::new(handler);
 711        self.add_handler(move |envelope, session| {
 712            let receipt = envelope.receipt();
 713            let handler = handler.clone();
 714            async move {
 715                let peer = session.peer.clone();
 716                let responded = Arc::new(AtomicBool::default());
 717                let response = Response {
 718                    peer: peer.clone(),
 719                    responded: responded.clone(),
 720                    receipt,
 721                };
 722                match (handler)(envelope.payload, response, session).await {
 723                    Ok(()) => {
 724                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 725                            Ok(())
 726                        } else {
 727                            Err(anyhow!("handler did not send a response"))?
 728                        }
 729                    }
 730                    Err(error) => {
 731                        let proto_err = match &error {
 732                            Error::Internal(err) => err.to_proto(),
 733                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 734                        };
 735                        peer.respond_with_error(receipt, proto_err)?;
 736                        Err(error)
 737                    }
 738                }
 739            }
 740        })
 741    }
 742
    /// Builds the future that serves a single client connection for its whole lifetime.
    ///
    /// When polled, the returned future:
    /// 1. returns immediately, logging an error, if the server is already tearing down;
    /// 2. registers `connection` with the peer and records the new `connection_id` on the span;
    /// 3. creates an HTTP client and, when a Supermaven admin API key is configured, a
    ///    Supermaven admin client, then builds the [`Session`] shared by all handlers;
    /// 4. sends the initial client update via `send_initial_client_update` and, on success,
    ///    drops `connection_guard`;
    /// 5. dispatches incoming messages to the registered handlers until one of these happens:
    ///    the server tears down, the I/O future completes, or the incoming stream ends;
    /// 6. runs `connection_lost`, unless the loop was exited by a teardown signal.
    ///
    /// `send_connection_id`, if provided, receives the connection id once the `Hello`
    /// message has been sent.
    ///
    /// NOTE(review): the early returns on HTTP-client or initial-update failure happen after
    /// `peer.add_connection`, and possibly after the connection was added to the pool. They
    /// skip `connection_lost`, so that cleanup does not run on those paths. Confirm whether
    /// this is intended.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields declared `Empty` are filled in below (or, presumably, by `update_span`).
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Server teardown flag: checked once on start and awaited inside the message loop.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // Within this function the HTTP client is only used to build the Supermaven client.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The initial handshake is done: release the connection slot the caller reserved
            // (see `handle_websocket_request`, which acquires it before the upgrade).
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Every in-flight handler (foreground or background) holds one permit.
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            // Number of permits currently in use, i.e. handlers still running.
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // Wait for a free handler slot before pulling the next message, so at most
                // `MAX_CONCURRENT_HANDLERS` handlers run at once on this connection.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in order: teardown (exits without `connection_lost`),
                // I/O completion, progress on foreground handlers, then a new message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                multi_lsp_query_request=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is released only once the handler finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor; foreground
                                // ones are driven by this loop via `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that are still pending.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
 924
 925    async fn send_initial_client_update(
 926        &self,
 927        connection_id: ConnectionId,
 928        zed_version: ZedVersion,
 929        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 930        session: &Session,
 931    ) -> Result<()> {
 932        self.peer.send(
 933            connection_id,
 934            proto::Hello {
 935                peer_id: Some(connection_id.into()),
 936            },
 937        )?;
 938        tracing::info!("sent hello message");
 939        if let Some(send_connection_id) = send_connection_id.take() {
 940            let _ = send_connection_id.send(connection_id);
 941        }
 942
 943        match &session.principal {
 944            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 945                if !user.connected_once {
 946                    self.peer.send(connection_id, proto::ShowContacts {})?;
 947                    self.app_state
 948                        .db
 949                        .set_user_connected_once(user.id, true)
 950                        .await?;
 951                }
 952
 953                update_user_plan(session).await?;
 954
 955                let contacts = self.app_state.db.get_contacts(user.id).await?;
 956
 957                {
 958                    let mut pool = self.connection_pool.lock();
 959                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 960                    self.peer.send(
 961                        connection_id,
 962                        build_initial_contacts_update(contacts, &pool),
 963                    )?;
 964                }
 965
 966                if should_auto_subscribe_to_channels(zed_version) {
 967                    subscribe_user_to_channels(user.id, session).await?;
 968                }
 969
 970                if let Some(incoming_call) =
 971                    self.app_state.db.incoming_call_for_user(user.id).await?
 972                {
 973                    self.peer.send(connection_id, incoming_call)?;
 974                }
 975
 976                update_user_contacts(user.id, session).await?;
 977            }
 978        }
 979
 980        Ok(())
 981    }
 982
 983    pub async fn invite_code_redeemed(
 984        self: &Arc<Self>,
 985        inviter_id: UserId,
 986        invitee_id: UserId,
 987    ) -> Result<()> {
 988        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 989            if let Some(code) = &user.invite_code {
 990                let pool = self.connection_pool.lock();
 991                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 992                for connection_id in pool.user_connection_ids(inviter_id) {
 993                    self.peer.send(
 994                        connection_id,
 995                        proto::UpdateContacts {
 996                            contacts: vec![invitee_contact.clone()],
 997                            ..Default::default()
 998                        },
 999                    )?;
1000                    self.peer.send(
1001                        connection_id,
1002                        proto::UpdateInviteInfo {
1003                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1004                            count: user.invite_count as u32,
1005                        },
1006                    )?;
1007                }
1008            }
1009        }
1010        Ok(())
1011    }
1012
1013    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1014        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1015            if let Some(invite_code) = &user.invite_code {
1016                let pool = self.connection_pool.lock();
1017                for connection_id in pool.user_connection_ids(user_id) {
1018                    self.peer.send(
1019                        connection_id,
1020                        proto::UpdateInviteInfo {
1021                            url: format!(
1022                                "{}{}",
1023                                self.app_state.config.invite_link_prefix, invite_code
1024                            ),
1025                            count: user.invite_count as u32,
1026                        },
1027                    )?;
1028                }
1029            }
1030        }
1031        Ok(())
1032    }
1033
1034    pub async fn update_plan_for_user(
1035        self: &Arc<Self>,
1036        user_id: UserId,
1037        update_user_plan: proto::UpdateUserPlan,
1038    ) -> Result<()> {
1039        let pool = self.connection_pool.lock();
1040        for connection_id in pool.user_connection_ids(user_id) {
1041            self.peer
1042                .send(connection_id, update_user_plan.clone())
1043                .trace_err();
1044        }
1045
1046        Ok(())
1047    }
1048
1049    /// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan`
1050    /// message on the Collab server.
1051    ///
1052    /// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint.
1053    pub async fn update_plan_for_user_legacy(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1054        let user = self
1055            .app_state
1056            .db
1057            .get_user_by_id(user_id)
1058            .await?
1059            .context("user not found")?;
1060
1061        let update_user_plan = make_update_user_plan_message(
1062            &user,
1063            user.admin,
1064            &self.app_state.db,
1065            self.app_state.llm_db.clone(),
1066        )
1067        .await?;
1068
1069        self.update_plan_for_user(user_id, update_user_plan).await
1070    }
1071
1072    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
1073        let pool = self.connection_pool.lock();
1074        for connection_id in pool.user_connection_ids(user_id) {
1075            self.peer
1076                .send(connection_id, proto::RefreshLlmToken {})
1077                .trace_err();
1078        }
1079    }
1080
1081    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1082        ServerSnapshot {
1083            connection_pool: ConnectionPoolGuard {
1084                guard: self.connection_pool.lock(),
1085                _not_send: PhantomData,
1086            },
1087            peer: &self.peer,
1088        }
1089    }
1090}
1091
1092impl Deref for ConnectionPoolGuard<'_> {
1093    type Target = ConnectionPool;
1094
1095    fn deref(&self) -> &Self::Target {
1096        &self.guard
1097    }
1098}
1099
1100impl DerefMut for ConnectionPoolGuard<'_> {
1101    fn deref_mut(&mut self) -> &mut Self::Target {
1102        &mut self.guard
1103    }
1104}
1105
1106impl Drop for ConnectionPoolGuard<'_> {
1107    fn drop(&mut self) {
1108        #[cfg(test)]
1109        self.check_invariants();
1110    }
1111}
1112
1113fn broadcast<F>(
1114    sender_id: Option<ConnectionId>,
1115    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1116    mut f: F,
1117) where
1118    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1119{
1120    for receiver_id in receiver_ids {
1121        if Some(receiver_id) != sender_id {
1122            if let Err(error) = f(receiver_id) {
1123                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1124            }
1125        }
1126    }
1127}
1128
1129pub struct ProtocolVersion(u32);
1130
1131impl Header for ProtocolVersion {
1132    fn name() -> &'static HeaderName {
1133        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1134        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1135    }
1136
1137    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1138    where
1139        Self: Sized,
1140        I: Iterator<Item = &'i axum::http::HeaderValue>,
1141    {
1142        let version = values
1143            .next()
1144            .ok_or_else(axum::headers::Error::invalid)?
1145            .to_str()
1146            .map_err(|_| axum::headers::Error::invalid())?
1147            .parse()
1148            .map_err(|_| axum::headers::Error::invalid())?;
1149        Ok(Self(version))
1150    }
1151
1152    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1153        values.extend([self.0.to_string().parse().unwrap()]);
1154    }
1155}
1156
1157pub struct AppVersionHeader(SemanticVersion);
1158impl Header for AppVersionHeader {
1159    fn name() -> &'static HeaderName {
1160        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1161        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1162    }
1163
1164    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1165    where
1166        Self: Sized,
1167        I: Iterator<Item = &'i axum::http::HeaderValue>,
1168    {
1169        let version = values
1170            .next()
1171            .ok_or_else(axum::headers::Error::invalid)?
1172            .to_str()
1173            .map_err(|_| axum::headers::Error::invalid())?
1174            .parse()
1175            .map_err(|_| axum::headers::Error::invalid())?;
1176        Ok(Self(version))
1177    }
1178
1179    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1180        values.extend([self.0.to_string().parse().unwrap()]);
1181    }
1182}
1183
1184pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1185    Router::new()
1186        .route("/rpc", get(handle_websocket_request))
1187        .layer(
1188            ServiceBuilder::new()
1189                .layer(Extension(server.app_state.clone()))
1190                .layer(middleware::from_fn(auth::validate_header)),
1191        )
1192        .route("/metrics", get(handle_metrics))
1193        .layer(Extension(server))
1194}
1195
1196pub async fn handle_websocket_request(
1197    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1198    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1199    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1200    Extension(server): Extension<Arc<Server>>,
1201    Extension(principal): Extension<Principal>,
1202    user_agent: Option<TypedHeader<UserAgent>>,
1203    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1204    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1205    ws: WebSocketUpgrade,
1206) -> axum::response::Response {
1207    if protocol_version != rpc::PROTOCOL_VERSION {
1208        return (
1209            StatusCode::UPGRADE_REQUIRED,
1210            "client must be upgraded".to_string(),
1211        )
1212            .into_response();
1213    }
1214
1215    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1216        return (
1217            StatusCode::UPGRADE_REQUIRED,
1218            "no version header found".to_string(),
1219        )
1220            .into_response();
1221    };
1222
1223    if !version.can_collaborate() {
1224        return (
1225            StatusCode::UPGRADE_REQUIRED,
1226            "client must be upgraded".to_string(),
1227        )
1228            .into_response();
1229    }
1230
1231    let socket_address = socket_address.to_string();
1232
1233    // Acquire connection guard before WebSocket upgrade
1234    let connection_guard = match ConnectionGuard::try_acquire() {
1235        Ok(guard) => guard,
1236        Err(()) => {
1237            return (
1238                StatusCode::SERVICE_UNAVAILABLE,
1239                "Too many concurrent connections",
1240            )
1241                .into_response();
1242        }
1243    };
1244
1245    ws.on_upgrade(move |socket| {
1246        let socket = socket
1247            .map_ok(to_tungstenite_message)
1248            .err_into()
1249            .with(|message| async move { to_axum_message(message) });
1250        let connection = Connection::new(Box::pin(socket));
1251        async move {
1252            server
1253                .handle_connection(
1254                    connection,
1255                    socket_address,
1256                    principal,
1257                    version,
1258                    user_agent.map(|header| header.to_string()),
1259                    country_code_header.map(|header| header.to_string()),
1260                    system_id_header.map(|header| header.to_string()),
1261                    None,
1262                    Executor::Production,
1263                    Some(connection_guard),
1264                )
1265                .await;
1266        }
1267    })
1268}
1269
1270pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1271    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1272    let connections_metric = CONNECTIONS_METRIC
1273        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1274
1275    let connections = server
1276        .connection_pool
1277        .lock()
1278        .connections()
1279        .filter(|connection| !connection.admin)
1280        .count();
1281    connections_metric.set(connections as _);
1282
1283    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1284    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1285        register_int_gauge!(
1286            "shared_projects",
1287            "number of open projects with one or more guests"
1288        )
1289        .unwrap()
1290    });
1291
1292    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1293    shared_projects_metric.set(shared_projects as _);
1294
1295    let encoder = prometheus::TextEncoder::new();
1296    let metric_families = prometheus::gather();
1297    let encoded_metrics = encoder
1298        .encode_to_string(&metric_families)
1299        .map_err(|err| anyhow!("{err}"))?;
1300    Ok(encoded_metrics)
1301}
1302
1303#[instrument(err, skip(executor))]
1304async fn connection_lost(
1305    session: Session,
1306    mut teardown: watch::Receiver<bool>,
1307    executor: Executor,
1308) -> Result<()> {
1309    session.peer.disconnect(session.connection_id);
1310    session
1311        .connection_pool()
1312        .await
1313        .remove_connection(session.connection_id)?;
1314
1315    session
1316        .db()
1317        .await
1318        .connection_lost(session.connection_id)
1319        .await
1320        .trace_err();
1321
1322    futures::select_biased! {
1323        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1324
1325            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1326            leave_room_for_session(&session, session.connection_id).await.trace_err();
1327            leave_channel_buffers_for_session(&session)
1328                .await
1329                .trace_err();
1330
1331            if !session
1332                .connection_pool()
1333                .await
1334                .is_user_online(session.user_id())
1335            {
1336                let db = session.db().await;
1337                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1338                    room_updated(&room, &session.peer);
1339                }
1340            }
1341
1342            update_user_contacts(session.user_id(), &session).await?;
1343        },
1344        _ = teardown.changed().fuse() => {}
1345    }
1346
1347    Ok(())
1348}
1349
1350/// Acknowledges a ping from a client, used to keep the connection alive.
1351async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1352    response.send(proto::Ack {})?;
1353    Ok(())
1354}
1355
1356/// Creates a new room for calling (outside of channels)
1357async fn create_room(
1358    _request: proto::CreateRoom,
1359    response: Response<proto::CreateRoom>,
1360    session: Session,
1361) -> Result<()> {
1362    let livekit_room = nanoid::nanoid!(30);
1363
1364    let live_kit_connection_info = util::maybe!(async {
1365        let live_kit = session.app_state.livekit_client.as_ref();
1366        let live_kit = live_kit?;
1367        let user_id = session.user_id().to_string();
1368
1369        let token = live_kit
1370            .room_token(&livekit_room, &user_id.to_string())
1371            .trace_err()?;
1372
1373        Some(proto::LiveKitConnectionInfo {
1374            server_url: live_kit.url().into(),
1375            token,
1376            can_publish: true,
1377        })
1378    })
1379    .await;
1380
1381    let room = session
1382        .db()
1383        .await
1384        .create_room(session.user_id(), session.connection_id, &livekit_room)
1385        .await?;
1386
1387    response.send(proto::CreateRoomResponse {
1388        room: Some(room.clone()),
1389        live_kit_connection_info,
1390    })?;
1391
1392    update_user_contacts(session.user_id(), &session).await?;
1393    Ok(())
1394}
1395
1396/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1397async fn join_room(
1398    request: proto::JoinRoom,
1399    response: Response<proto::JoinRoom>,
1400    session: Session,
1401) -> Result<()> {
1402    let room_id = RoomId::from_proto(request.id);
1403
1404    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1405
1406    if let Some(channel_id) = channel_id {
1407        return join_channel_internal(channel_id, Box::new(response), session).await;
1408    }
1409
1410    let joined_room = {
1411        let room = session
1412            .db()
1413            .await
1414            .join_room(room_id, session.user_id(), session.connection_id)
1415            .await?;
1416        room_updated(&room.room, &session.peer);
1417        room.into_inner()
1418    };
1419
1420    for connection_id in session
1421        .connection_pool()
1422        .await
1423        .user_connection_ids(session.user_id())
1424    {
1425        session
1426            .peer
1427            .send(
1428                connection_id,
1429                proto::CallCanceled {
1430                    room_id: room_id.to_proto(),
1431                },
1432            )
1433            .trace_err();
1434    }
1435
1436    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1437    {
1438        live_kit
1439            .room_token(
1440                &joined_room.room.livekit_room,
1441                &session.user_id().to_string(),
1442            )
1443            .trace_err()
1444            .map(|token| proto::LiveKitConnectionInfo {
1445                server_url: live_kit.url().into(),
1446                token,
1447                can_publish: true,
1448            })
1449    } else {
1450        None
1451    };
1452
1453    response.send(proto::JoinRoomResponse {
1454        room: Some(joined_room.room),
1455        channel_id: None,
1456        live_kit_connection_info,
1457    })?;
1458
1459    update_user_contacts(session.user_id(), &session).await?;
1460    Ok(())
1461}
1462
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Restores this connection's membership through the database, replies with the
/// room state together with the reshared and rejoined projects, and then
/// notifies everyone affected: the room (via `room_updated`), the collaborators
/// of each reshared and rejoined project, and, if the database reports a
/// channel for the room, the channel (via `channel_updated`). Finally refreshes
/// this user's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Filled in from the database result at the end of the scope below, so the
    // value returned by `rejoin_room` does not outlive that scope.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Reply to the rejoining client before notifying anyone else.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // For each reshared project: tell its collaborators that this user's
        // peer id moved from the old connection to the current one, then
        // forward the project's worktree list to them. Send failures are only
        // passed to `trace_err`.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Notify collaborators of rejoined projects and stream those projects'
        // state back to this connection. This takes `&mut` because it moves
        // worktree/repository data out of the rejoined projects.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // Only rooms that belong to a channel need a channel update.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1554
1555fn notify_rejoined_projects(
1556    rejoined_projects: &mut Vec<RejoinedProject>,
1557    session: &Session,
1558) -> Result<()> {
1559    for project in rejoined_projects.iter() {
1560        for collaborator in &project.collaborators {
1561            session
1562                .peer
1563                .send(
1564                    collaborator.connection_id,
1565                    proto::UpdateProjectCollaborator {
1566                        project_id: project.id.to_proto(),
1567                        old_peer_id: Some(project.old_connection_id.into()),
1568                        new_peer_id: Some(session.connection_id.into()),
1569                    },
1570                )
1571                .trace_err();
1572        }
1573    }
1574
1575    for project in rejoined_projects {
1576        for worktree in mem::take(&mut project.worktrees) {
1577            // Stream this worktree's entries.
1578            let message = proto::UpdateWorktree {
1579                project_id: project.id.to_proto(),
1580                worktree_id: worktree.id,
1581                abs_path: worktree.abs_path.clone(),
1582                root_name: worktree.root_name,
1583                updated_entries: worktree.updated_entries,
1584                removed_entries: worktree.removed_entries,
1585                scan_id: worktree.scan_id,
1586                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1587                updated_repositories: worktree.updated_repositories,
1588                removed_repositories: worktree.removed_repositories,
1589            };
1590            for update in proto::split_worktree_update(message) {
1591                session.peer.send(session.connection_id, update)?;
1592            }
1593
1594            // Stream this worktree's diagnostics.
1595            for summary in worktree.diagnostic_summaries {
1596                session.peer.send(
1597                    session.connection_id,
1598                    proto::UpdateDiagnosticSummary {
1599                        project_id: project.id.to_proto(),
1600                        worktree_id: worktree.id,
1601                        summary: Some(summary),
1602                    },
1603                )?;
1604            }
1605
1606            for settings_file in worktree.settings_files {
1607                session.peer.send(
1608                    session.connection_id,
1609                    proto::UpdateWorktreeSettings {
1610                        project_id: project.id.to_proto(),
1611                        worktree_id: worktree.id,
1612                        path: settings_file.path,
1613                        content: Some(settings_file.content),
1614                        kind: Some(settings_file.kind.to_proto().into()),
1615                    },
1616                )?;
1617            }
1618        }
1619
1620        for repository in mem::take(&mut project.updated_repositories) {
1621            for update in split_repository_update(repository) {
1622                session.peer.send(session.connection_id, update)?;
1623            }
1624        }
1625
1626        for id in mem::take(&mut project.removed_repositories) {
1627            session.peer.send(
1628                session.connection_id,
1629                proto::RemoveRepository {
1630                    project_id: project.id.to_proto(),
1631                    id,
1632                },
1633            )?;
1634        }
1635    }
1636
1637    Ok(())
1638}
1639
1640/// leave room disconnects from the room.
1641async fn leave_room(
1642    _: proto::LeaveRoom,
1643    response: Response<proto::LeaveRoom>,
1644    session: Session,
1645) -> Result<()> {
1646    leave_room_for_session(&session, session.connection_id).await?;
1647    response.send(proto::Ack {})?;
1648    Ok(())
1649}
1650
1651/// Updates the permissions of someone else in the room.
1652async fn set_room_participant_role(
1653    request: proto::SetRoomParticipantRole,
1654    response: Response<proto::SetRoomParticipantRole>,
1655    session: Session,
1656) -> Result<()> {
1657    let user_id = UserId::from_proto(request.user_id);
1658    let role = ChannelRole::from(request.role());
1659
1660    let (livekit_room, can_publish) = {
1661        let room = session
1662            .db()
1663            .await
1664            .set_room_participant_role(
1665                session.user_id(),
1666                RoomId::from_proto(request.room_id),
1667                user_id,
1668                role,
1669            )
1670            .await?;
1671
1672        let livekit_room = room.livekit_room.clone();
1673        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1674        room_updated(&room, &session.peer);
1675        (livekit_room, can_publish)
1676    };
1677
1678    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1679        live_kit
1680            .update_participant(
1681                livekit_room.clone(),
1682                request.user_id.to_string(),
1683                livekit_api::proto::ParticipantPermission {
1684                    can_subscribe: true,
1685                    can_publish,
1686                    can_publish_data: can_publish,
1687                    hidden: false,
1688                    recorder: false,
1689                },
1690            )
1691            .await
1692            .trace_err();
1693    }
1694
1695    response.send(proto::Ack {})?;
1696    Ok(())
1697}
1698
1699/// Call someone else into the current room
1700async fn call(
1701    request: proto::Call,
1702    response: Response<proto::Call>,
1703    session: Session,
1704) -> Result<()> {
1705    let room_id = RoomId::from_proto(request.room_id);
1706    let calling_user_id = session.user_id();
1707    let calling_connection_id = session.connection_id;
1708    let called_user_id = UserId::from_proto(request.called_user_id);
1709    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1710    if !session
1711        .db()
1712        .await
1713        .has_contact(calling_user_id, called_user_id)
1714        .await?
1715    {
1716        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1717    }
1718
1719    let incoming_call = {
1720        let (room, incoming_call) = &mut *session
1721            .db()
1722            .await
1723            .call(
1724                room_id,
1725                calling_user_id,
1726                calling_connection_id,
1727                called_user_id,
1728                initial_project_id,
1729            )
1730            .await?;
1731        room_updated(room, &session.peer);
1732        mem::take(incoming_call)
1733    };
1734    update_user_contacts(called_user_id, &session).await?;
1735
1736    let mut calls = session
1737        .connection_pool()
1738        .await
1739        .user_connection_ids(called_user_id)
1740        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1741        .collect::<FuturesUnordered<_>>();
1742
1743    while let Some(call_response) = calls.next().await {
1744        match call_response.as_ref() {
1745            Ok(_) => {
1746                response.send(proto::Ack {})?;
1747                return Ok(());
1748            }
1749            Err(_) => {
1750                call_response.trace_err();
1751            }
1752        }
1753    }
1754
1755    {
1756        let room = session
1757            .db()
1758            .await
1759            .call_failed(room_id, called_user_id)
1760            .await?;
1761        room_updated(&room, &session.peer);
1762    }
1763    update_user_contacts(called_user_id, &session).await?;
1764
1765    Err(anyhow!("failed to ring user"))?
1766}
1767
1768/// Cancel an outgoing call.
1769async fn cancel_call(
1770    request: proto::CancelCall,
1771    response: Response<proto::CancelCall>,
1772    session: Session,
1773) -> Result<()> {
1774    let called_user_id = UserId::from_proto(request.called_user_id);
1775    let room_id = RoomId::from_proto(request.room_id);
1776    {
1777        let room = session
1778            .db()
1779            .await
1780            .cancel_call(room_id, session.connection_id, called_user_id)
1781            .await?;
1782        room_updated(&room, &session.peer);
1783    }
1784
1785    for connection_id in session
1786        .connection_pool()
1787        .await
1788        .user_connection_ids(called_user_id)
1789    {
1790        session
1791            .peer
1792            .send(
1793                connection_id,
1794                proto::CallCanceled {
1795                    room_id: room_id.to_proto(),
1796                },
1797            )
1798            .trace_err();
1799    }
1800    response.send(proto::Ack {})?;
1801
1802    update_user_contacts(called_user_id, &session).await?;
1803    Ok(())
1804}
1805
1806/// Decline an incoming call.
1807async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1808    let room_id = RoomId::from_proto(message.room_id);
1809    {
1810        let room = session
1811            .db()
1812            .await
1813            .decline_call(Some(room_id), session.user_id())
1814            .await?
1815            .context("declining call")?;
1816        room_updated(&room, &session.peer);
1817    }
1818
1819    for connection_id in session
1820        .connection_pool()
1821        .await
1822        .user_connection_ids(session.user_id())
1823    {
1824        session
1825            .peer
1826            .send(
1827                connection_id,
1828                proto::CallCanceled {
1829                    room_id: room_id.to_proto(),
1830                },
1831            )
1832            .trace_err();
1833    }
1834    update_user_contacts(session.user_id(), &session).await?;
1835    Ok(())
1836}
1837
1838/// Updates other participants in the room with your current location.
1839async fn update_participant_location(
1840    request: proto::UpdateParticipantLocation,
1841    response: Response<proto::UpdateParticipantLocation>,
1842    session: Session,
1843) -> Result<()> {
1844    let room_id = RoomId::from_proto(request.room_id);
1845    let location = request.location.context("invalid location")?;
1846
1847    let db = session.db().await;
1848    let room = db
1849        .update_room_participant_location(room_id, session.connection_id, location)
1850        .await?;
1851
1852    room_updated(&room, &session.peer);
1853    response.send(proto::Ack {})?;
1854    Ok(())
1855}
1856
1857/// Share a project into the room.
1858async fn share_project(
1859    request: proto::ShareProject,
1860    response: Response<proto::ShareProject>,
1861    session: Session,
1862) -> Result<()> {
1863    let (project_id, room) = &*session
1864        .db()
1865        .await
1866        .share_project(
1867            RoomId::from_proto(request.room_id),
1868            session.connection_id,
1869            &request.worktrees,
1870            request.is_ssh_project,
1871        )
1872        .await?;
1873    response.send(proto::ShareProjectResponse {
1874        project_id: project_id.to_proto(),
1875    })?;
1876    room_updated(room, &session.peer);
1877
1878    Ok(())
1879}
1880
1881/// Unshare a project from the room.
1882async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1883    let project_id = ProjectId::from_proto(message.project_id);
1884    unshare_project_internal(project_id, session.connection_id, &session).await
1885}
1886
1887async fn unshare_project_internal(
1888    project_id: ProjectId,
1889    connection_id: ConnectionId,
1890    session: &Session,
1891) -> Result<()> {
1892    let delete = {
1893        let room_guard = session
1894            .db()
1895            .await
1896            .unshare_project(project_id, connection_id)
1897            .await?;
1898
1899        let (delete, room, guest_connection_ids) = &*room_guard;
1900
1901        let message = proto::UnshareProject {
1902            project_id: project_id.to_proto(),
1903        };
1904
1905        broadcast(
1906            Some(connection_id),
1907            guest_connection_ids.iter().copied(),
1908            |conn_id| session.peer.send(conn_id, message.clone()),
1909        );
1910        if let Some(room) = room {
1911            room_updated(room, &session.peer);
1912        }
1913
1914        *delete
1915    };
1916
1917    if delete {
1918        let db = session.db().await;
1919        db.delete_project(project_id).await?;
1920    }
1921
1922    Ok(())
1923}
1924
1925/// Join someone elses shared project.
1926async fn join_project(
1927    request: proto::JoinProject,
1928    response: Response<proto::JoinProject>,
1929    session: Session,
1930) -> Result<()> {
1931    let project_id = ProjectId::from_proto(request.project_id);
1932
1933    tracing::info!(%project_id, "join project");
1934
1935    let db = session.db().await;
1936    let (project, replica_id) = &mut *db
1937        .join_project(
1938            project_id,
1939            session.connection_id,
1940            session.user_id(),
1941            request.committer_name.clone(),
1942            request.committer_email.clone(),
1943        )
1944        .await?;
1945    drop(db);
1946    tracing::info!(%project_id, "join remote project");
1947    let collaborators = project
1948        .collaborators
1949        .iter()
1950        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1951        .map(|collaborator| collaborator.to_proto())
1952        .collect::<Vec<_>>();
1953    let project_id = project.id;
1954    let guest_user_id = session.user_id();
1955
1956    let worktrees = project
1957        .worktrees
1958        .iter()
1959        .map(|(id, worktree)| proto::WorktreeMetadata {
1960            id: *id,
1961            root_name: worktree.root_name.clone(),
1962            visible: worktree.visible,
1963            abs_path: worktree.abs_path.clone(),
1964        })
1965        .collect::<Vec<_>>();
1966
1967    let add_project_collaborator = proto::AddProjectCollaborator {
1968        project_id: project_id.to_proto(),
1969        collaborator: Some(proto::Collaborator {
1970            peer_id: Some(session.connection_id.into()),
1971            replica_id: replica_id.0 as u32,
1972            user_id: guest_user_id.to_proto(),
1973            is_host: false,
1974            committer_name: request.committer_name.clone(),
1975            committer_email: request.committer_email.clone(),
1976        }),
1977    };
1978
1979    for collaborator in &collaborators {
1980        session
1981            .peer
1982            .send(
1983                collaborator.peer_id.unwrap().into(),
1984                add_project_collaborator.clone(),
1985            )
1986            .trace_err();
1987    }
1988
1989    // First, we send the metadata associated with each worktree.
1990    let (language_servers, language_server_capabilities) = project
1991        .language_servers
1992        .clone()
1993        .into_iter()
1994        .map(|server| (server.server, server.capabilities))
1995        .unzip();
1996    response.send(proto::JoinProjectResponse {
1997        project_id: project.id.0 as u64,
1998        worktrees: worktrees.clone(),
1999        replica_id: replica_id.0 as u32,
2000        collaborators: collaborators.clone(),
2001        language_servers,
2002        language_server_capabilities,
2003        role: project.role.into(),
2004    })?;
2005
2006    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2007        // Stream this worktree's entries.
2008        let message = proto::UpdateWorktree {
2009            project_id: project_id.to_proto(),
2010            worktree_id,
2011            abs_path: worktree.abs_path.clone(),
2012            root_name: worktree.root_name,
2013            updated_entries: worktree.entries,
2014            removed_entries: Default::default(),
2015            scan_id: worktree.scan_id,
2016            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2017            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2018            removed_repositories: Default::default(),
2019        };
2020        for update in proto::split_worktree_update(message) {
2021            session.peer.send(session.connection_id, update.clone())?;
2022        }
2023
2024        // Stream this worktree's diagnostics.
2025        for summary in worktree.diagnostic_summaries {
2026            session.peer.send(
2027                session.connection_id,
2028                proto::UpdateDiagnosticSummary {
2029                    project_id: project_id.to_proto(),
2030                    worktree_id: worktree.id,
2031                    summary: Some(summary),
2032                },
2033            )?;
2034        }
2035
2036        for settings_file in worktree.settings_files {
2037            session.peer.send(
2038                session.connection_id,
2039                proto::UpdateWorktreeSettings {
2040                    project_id: project_id.to_proto(),
2041                    worktree_id: worktree.id,
2042                    path: settings_file.path,
2043                    content: Some(settings_file.content),
2044                    kind: Some(settings_file.kind.to_proto() as i32),
2045                },
2046            )?;
2047        }
2048    }
2049
2050    for repository in mem::take(&mut project.repositories) {
2051        for update in split_repository_update(repository) {
2052            session.peer.send(session.connection_id, update)?;
2053        }
2054    }
2055
2056    for language_server in &project.language_servers {
2057        session.peer.send(
2058            session.connection_id,
2059            proto::UpdateLanguageServer {
2060                project_id: project_id.to_proto(),
2061                server_name: Some(language_server.server.name.clone()),
2062                language_server_id: language_server.server.id,
2063                variant: Some(
2064                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2065                        proto::LspDiskBasedDiagnosticsUpdated {},
2066                    ),
2067                ),
2068            },
2069        )?;
2070    }
2071
2072    Ok(())
2073}
2074
2075/// Leave someone elses shared project.
2076async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
2077    let sender_id = session.connection_id;
2078    let project_id = ProjectId::from_proto(request.project_id);
2079    let db = session.db().await;
2080
2081    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2082    tracing::info!(
2083        %project_id,
2084        "leave project"
2085    );
2086
2087    project_left(project, &session);
2088    if let Some(room) = room {
2089        room_updated(room, &session.peer);
2090    }
2091
2092    Ok(())
2093}
2094
2095/// Updates other participants with changes to the project
2096async fn update_project(
2097    request: proto::UpdateProject,
2098    response: Response<proto::UpdateProject>,
2099    session: Session,
2100) -> Result<()> {
2101    let project_id = ProjectId::from_proto(request.project_id);
2102    let (room, guest_connection_ids) = &*session
2103        .db()
2104        .await
2105        .update_project(project_id, session.connection_id, &request.worktrees)
2106        .await?;
2107    broadcast(
2108        Some(session.connection_id),
2109        guest_connection_ids.iter().copied(),
2110        |connection_id| {
2111            session
2112                .peer
2113                .forward_send(session.connection_id, connection_id, request.clone())
2114        },
2115    );
2116    if let Some(room) = room {
2117        room_updated(room, &session.peer);
2118    }
2119    response.send(proto::Ack {})?;
2120
2121    Ok(())
2122}
2123
2124/// Updates other participants with changes to the worktree
2125async fn update_worktree(
2126    request: proto::UpdateWorktree,
2127    response: Response<proto::UpdateWorktree>,
2128    session: Session,
2129) -> Result<()> {
2130    let guest_connection_ids = session
2131        .db()
2132        .await
2133        .update_worktree(&request, session.connection_id)
2134        .await?;
2135
2136    broadcast(
2137        Some(session.connection_id),
2138        guest_connection_ids.iter().copied(),
2139        |connection_id| {
2140            session
2141                .peer
2142                .forward_send(session.connection_id, connection_id, request.clone())
2143        },
2144    );
2145    response.send(proto::Ack {})?;
2146    Ok(())
2147}
2148
2149async fn update_repository(
2150    request: proto::UpdateRepository,
2151    response: Response<proto::UpdateRepository>,
2152    session: Session,
2153) -> Result<()> {
2154    let guest_connection_ids = session
2155        .db()
2156        .await
2157        .update_repository(&request, session.connection_id)
2158        .await?;
2159
2160    broadcast(
2161        Some(session.connection_id),
2162        guest_connection_ids.iter().copied(),
2163        |connection_id| {
2164            session
2165                .peer
2166                .forward_send(session.connection_id, connection_id, request.clone())
2167        },
2168    );
2169    response.send(proto::Ack {})?;
2170    Ok(())
2171}
2172
2173async fn remove_repository(
2174    request: proto::RemoveRepository,
2175    response: Response<proto::RemoveRepository>,
2176    session: Session,
2177) -> Result<()> {
2178    let guest_connection_ids = session
2179        .db()
2180        .await
2181        .remove_repository(&request, session.connection_id)
2182        .await?;
2183
2184    broadcast(
2185        Some(session.connection_id),
2186        guest_connection_ids.iter().copied(),
2187        |connection_id| {
2188            session
2189                .peer
2190                .forward_send(session.connection_id, connection_id, request.clone())
2191        },
2192    );
2193    response.send(proto::Ack {})?;
2194    Ok(())
2195}
2196
2197/// Updates other participants with changes to the diagnostics
2198async fn update_diagnostic_summary(
2199    message: proto::UpdateDiagnosticSummary,
2200    session: Session,
2201) -> Result<()> {
2202    let guest_connection_ids = session
2203        .db()
2204        .await
2205        .update_diagnostic_summary(&message, session.connection_id)
2206        .await?;
2207
2208    broadcast(
2209        Some(session.connection_id),
2210        guest_connection_ids.iter().copied(),
2211        |connection_id| {
2212            session
2213                .peer
2214                .forward_send(session.connection_id, connection_id, message.clone())
2215        },
2216    );
2217
2218    Ok(())
2219}
2220
2221/// Updates other participants with changes to the worktree settings
2222async fn update_worktree_settings(
2223    message: proto::UpdateWorktreeSettings,
2224    session: Session,
2225) -> Result<()> {
2226    let guest_connection_ids = session
2227        .db()
2228        .await
2229        .update_worktree_settings(&message, session.connection_id)
2230        .await?;
2231
2232    broadcast(
2233        Some(session.connection_id),
2234        guest_connection_ids.iter().copied(),
2235        |connection_id| {
2236            session
2237                .peer
2238                .forward_send(session.connection_id, connection_id, message.clone())
2239        },
2240    );
2241
2242    Ok(())
2243}
2244
2245/// Notify other participants that a language server has started.
2246async fn start_language_server(
2247    request: proto::StartLanguageServer,
2248    session: Session,
2249) -> Result<()> {
2250    let guest_connection_ids = session
2251        .db()
2252        .await
2253        .start_language_server(&request, session.connection_id)
2254        .await?;
2255
2256    broadcast(
2257        Some(session.connection_id),
2258        guest_connection_ids.iter().copied(),
2259        |connection_id| {
2260            session
2261                .peer
2262                .forward_send(session.connection_id, connection_id, request.clone())
2263        },
2264    );
2265    Ok(())
2266}
2267
2268/// Notify other participants that a language server has changed.
2269async fn update_language_server(
2270    request: proto::UpdateLanguageServer,
2271    session: Session,
2272) -> Result<()> {
2273    let project_id = ProjectId::from_proto(request.project_id);
2274    let db = session.db().await;
2275
2276    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2277    {
2278        if let Some(capabilities) = update.capabilities.clone() {
2279            db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2280                .await?;
2281        }
2282    }
2283
2284    let project_connection_ids = db
2285        .project_connection_ids(project_id, session.connection_id, true)
2286        .await?;
2287    broadcast(
2288        Some(session.connection_id),
2289        project_connection_ids.iter().copied(),
2290        |connection_id| {
2291            session
2292                .peer
2293                .forward_send(session.connection_id, connection_id, request.clone())
2294        },
2295    );
2296    Ok(())
2297}
2298
2299/// forward a project request to the host. These requests should be read only
2300/// as guests are allowed to send them.
2301async fn forward_read_only_project_request<T>(
2302    request: T,
2303    response: Response<T>,
2304    session: Session,
2305) -> Result<()>
2306where
2307    T: EntityMessage + RequestMessage,
2308{
2309    let project_id = ProjectId::from_proto(request.remote_entity_id());
2310    let host_connection_id = session
2311        .db()
2312        .await
2313        .host_for_read_only_project_request(project_id, session.connection_id)
2314        .await?;
2315    let payload = session
2316        .peer
2317        .forward_request(session.connection_id, host_connection_id, request)
2318        .await?;
2319    response.send(payload)?;
2320    Ok(())
2321}
2322
2323/// forward a project request to the host. These requests are disallowed
2324/// for guests.
2325async fn forward_mutating_project_request<T>(
2326    request: T,
2327    response: Response<T>,
2328    session: Session,
2329) -> Result<()>
2330where
2331    T: EntityMessage + RequestMessage,
2332{
2333    let project_id = ProjectId::from_proto(request.remote_entity_id());
2334
2335    let host_connection_id = session
2336        .db()
2337        .await
2338        .host_for_mutating_project_request(project_id, session.connection_id)
2339        .await?;
2340    let payload = session
2341        .peer
2342        .forward_request(session.connection_id, host_connection_id, request)
2343        .await?;
2344    response.send(payload)?;
2345    Ok(())
2346}
2347
2348async fn multi_lsp_query(
2349    request: MultiLspQuery,
2350    response: Response<MultiLspQuery>,
2351    session: Session,
2352) -> Result<()> {
2353    tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2354    tracing::info!("multi_lsp_query message received");
2355    forward_mutating_project_request(request, response, session).await
2356}
2357
2358/// Notify other participants that a new buffer has been created
2359async fn create_buffer_for_peer(
2360    request: proto::CreateBufferForPeer,
2361    session: Session,
2362) -> Result<()> {
2363    session
2364        .db()
2365        .await
2366        .check_user_is_project_host(
2367            ProjectId::from_proto(request.project_id),
2368            session.connection_id,
2369        )
2370        .await?;
2371    let peer_id = request.peer_id.context("invalid peer id")?;
2372    session
2373        .peer
2374        .forward_send(session.connection_id, peer_id.into(), request)?;
2375    Ok(())
2376}
2377
2378/// Notify other participants that a buffer has been updated. This is
2379/// allowed for guests as long as the update is limited to selections.
2380async fn update_buffer(
2381    request: proto::UpdateBuffer,
2382    response: Response<proto::UpdateBuffer>,
2383    session: Session,
2384) -> Result<()> {
2385    let project_id = ProjectId::from_proto(request.project_id);
2386    let mut capability = Capability::ReadOnly;
2387
2388    for op in request.operations.iter() {
2389        match op.variant {
2390            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2391            Some(_) => capability = Capability::ReadWrite,
2392        }
2393    }
2394
2395    let host = {
2396        let guard = session
2397            .db()
2398            .await
2399            .connections_for_buffer_update(project_id, session.connection_id, capability)
2400            .await?;
2401
2402        let (host, guests) = &*guard;
2403
2404        broadcast(
2405            Some(session.connection_id),
2406            guests.clone(),
2407            |connection_id| {
2408                session
2409                    .peer
2410                    .forward_send(session.connection_id, connection_id, request.clone())
2411            },
2412        );
2413
2414        *host
2415    };
2416
2417    if host != session.connection_id {
2418        session
2419            .peer
2420            .forward_request(session.connection_id, host, request.clone())
2421            .await?;
2422    }
2423
2424    response.send(proto::Ack {})?;
2425    Ok(())
2426}
2427
2428async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2429    let project_id = ProjectId::from_proto(message.project_id);
2430
2431    let operation = message.operation.as_ref().context("invalid operation")?;
2432    let capability = match operation.variant.as_ref() {
2433        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2434            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2435                match buffer_op.variant {
2436                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2437                        Capability::ReadOnly
2438                    }
2439                    _ => Capability::ReadWrite,
2440                }
2441            } else {
2442                Capability::ReadWrite
2443            }
2444        }
2445        Some(_) => Capability::ReadWrite,
2446        None => Capability::ReadOnly,
2447    };
2448
2449    let guard = session
2450        .db()
2451        .await
2452        .connections_for_buffer_update(project_id, session.connection_id, capability)
2453        .await?;
2454
2455    let (host, guests) = &*guard;
2456
2457    broadcast(
2458        Some(session.connection_id),
2459        guests.iter().chain([host]).copied(),
2460        |connection_id| {
2461            session
2462                .peer
2463                .forward_send(session.connection_id, connection_id, message.clone())
2464        },
2465    );
2466
2467    Ok(())
2468}
2469
2470/// Notify other participants that a project has been updated.
2471async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2472    request: T,
2473    session: Session,
2474) -> Result<()> {
2475    let project_id = ProjectId::from_proto(request.remote_entity_id());
2476    let project_connection_ids = session
2477        .db()
2478        .await
2479        .project_connection_ids(project_id, session.connection_id, false)
2480        .await?;
2481
2482    broadcast(
2483        Some(session.connection_id),
2484        project_connection_ids.iter().copied(),
2485        |connection_id| {
2486            session
2487                .peer
2488                .forward_send(session.connection_id, connection_id, request.clone())
2489        },
2490    );
2491    Ok(())
2492}
2493
2494/// Start following another user in a call.
2495async fn follow(
2496    request: proto::Follow,
2497    response: Response<proto::Follow>,
2498    session: Session,
2499) -> Result<()> {
2500    let room_id = RoomId::from_proto(request.room_id);
2501    let project_id = request.project_id.map(ProjectId::from_proto);
2502    let leader_id = request.leader_id.context("invalid leader id")?.into();
2503    let follower_id = session.connection_id;
2504
2505    session
2506        .db()
2507        .await
2508        .check_room_participants(room_id, leader_id, session.connection_id)
2509        .await?;
2510
2511    let response_payload = session
2512        .peer
2513        .forward_request(session.connection_id, leader_id, request)
2514        .await?;
2515    response.send(response_payload)?;
2516
2517    if let Some(project_id) = project_id {
2518        let room = session
2519            .db()
2520            .await
2521            .follow(room_id, project_id, leader_id, follower_id)
2522            .await?;
2523        room_updated(&room, &session.peer);
2524    }
2525
2526    Ok(())
2527}
2528
2529/// Stop following another user in a call.
2530async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2531    let room_id = RoomId::from_proto(request.room_id);
2532    let project_id = request.project_id.map(ProjectId::from_proto);
2533    let leader_id = request.leader_id.context("invalid leader id")?.into();
2534    let follower_id = session.connection_id;
2535
2536    session
2537        .db()
2538        .await
2539        .check_room_participants(room_id, leader_id, session.connection_id)
2540        .await?;
2541
2542    session
2543        .peer
2544        .forward_send(session.connection_id, leader_id, request)?;
2545
2546    if let Some(project_id) = project_id {
2547        let room = session
2548            .db()
2549            .await
2550            .unfollow(room_id, project_id, leader_id, follower_id)
2551            .await?;
2552        room_updated(&room, &session.peer);
2553    }
2554
2555    Ok(())
2556}
2557
2558/// Notify everyone following you of your current location.
2559async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2560    let room_id = RoomId::from_proto(request.room_id);
2561    let database = session.db.lock().await;
2562
2563    let connection_ids = if let Some(project_id) = request.project_id {
2564        let project_id = ProjectId::from_proto(project_id);
2565        database
2566            .project_connection_ids(project_id, session.connection_id, true)
2567            .await?
2568    } else {
2569        database
2570            .room_connection_ids(room_id, session.connection_id)
2571            .await?
2572    };
2573
2574    // For now, don't send view update messages back to that view's current leader.
2575    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2576        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2577        _ => None,
2578    });
2579
2580    for connection_id in connection_ids.iter().cloned() {
2581        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2582            session
2583                .peer
2584                .forward_send(session.connection_id, connection_id, request.clone())?;
2585        }
2586    }
2587    Ok(())
2588}
2589
2590/// Get public data about users.
2591async fn get_users(
2592    request: proto::GetUsers,
2593    response: Response<proto::GetUsers>,
2594    session: Session,
2595) -> Result<()> {
2596    let user_ids = request
2597        .user_ids
2598        .into_iter()
2599        .map(UserId::from_proto)
2600        .collect();
2601    let users = session
2602        .db()
2603        .await
2604        .get_users_by_ids(user_ids)
2605        .await?
2606        .into_iter()
2607        .map(|user| proto::User {
2608            id: user.id.to_proto(),
2609            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2610            github_login: user.github_login,
2611            name: user.name,
2612        })
2613        .collect();
2614    response.send(proto::UsersResponse { users })?;
2615    Ok(())
2616}
2617
2618/// Search for users (to invite) buy Github login
2619async fn fuzzy_search_users(
2620    request: proto::FuzzySearchUsers,
2621    response: Response<proto::FuzzySearchUsers>,
2622    session: Session,
2623) -> Result<()> {
2624    let query = request.query;
2625    let users = match query.len() {
2626        0 => vec![],
2627        1 | 2 => session
2628            .db()
2629            .await
2630            .get_user_by_github_login(&query)
2631            .await?
2632            .into_iter()
2633            .collect(),
2634        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2635    };
2636    let users = users
2637        .into_iter()
2638        .filter(|user| user.id != session.user_id())
2639        .map(|user| proto::User {
2640            id: user.id.to_proto(),
2641            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2642            github_login: user.github_login,
2643            name: user.name,
2644        })
2645        .collect();
2646    response.send(proto::UsersResponse { users })?;
2647    Ok(())
2648}
2649
2650/// Send a contact request to another user.
2651async fn request_contact(
2652    request: proto::RequestContact,
2653    response: Response<proto::RequestContact>,
2654    session: Session,
2655) -> Result<()> {
2656    let requester_id = session.user_id();
2657    let responder_id = UserId::from_proto(request.responder_id);
2658    if requester_id == responder_id {
2659        return Err(anyhow!("cannot add yourself as a contact"))?;
2660    }
2661
2662    let notifications = session
2663        .db()
2664        .await
2665        .send_contact_request(requester_id, responder_id)
2666        .await?;
2667
2668    // Update outgoing contact requests of requester
2669    let mut update = proto::UpdateContacts::default();
2670    update.outgoing_requests.push(responder_id.to_proto());
2671    for connection_id in session
2672        .connection_pool()
2673        .await
2674        .user_connection_ids(requester_id)
2675    {
2676        session.peer.send(connection_id, update.clone())?;
2677    }
2678
2679    // Update incoming contact requests of responder
2680    let mut update = proto::UpdateContacts::default();
2681    update
2682        .incoming_requests
2683        .push(proto::IncomingContactRequest {
2684            requester_id: requester_id.to_proto(),
2685        });
2686    let connection_pool = session.connection_pool().await;
2687    for connection_id in connection_pool.user_connection_ids(responder_id) {
2688        session.peer.send(connection_id, update.clone())?;
2689    }
2690
2691    send_notifications(&connection_pool, &session.peer, notifications);
2692
2693    response.send(proto::Ack {})?;
2694    Ok(())
2695}
2696
/// Accept or decline a contact request, or dismiss its notification.
///
/// - `Dismiss`: dismisses the responder's contact notification for `requester_id` in the
///   database. No `UpdateContacts` messages are sent.
/// - Any other response: records the answer; only `Accept` counts as accepted. Then each
///   user's connections receive an `UpdateContacts` that clears the pending request and,
///   if accepted, adds the other user as a contact together with that user's busy status.
///   Notifications returned by the database are delivered afterwards.
///
/// In every case the request is acknowledged once the above has succeeded.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: Session,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    // NOTE: `db` stays bound until the function returns, so whatever it guards is still held
    // while the connection pool is locked and messages are sent below.
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        // Anything other than an explicit `Accept` is treated as a decline.
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy status is embedded in the contact entries built below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Tell the responder: drop the incoming request and, if accepted, add the requester.
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Tell the requester: drop the outgoing request and, if accepted, add the responder.
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2754
2755/// Remove a contact.
2756async fn remove_contact(
2757    request: proto::RemoveContact,
2758    response: Response<proto::RemoveContact>,
2759    session: Session,
2760) -> Result<()> {
2761    let requester_id = session.user_id();
2762    let responder_id = UserId::from_proto(request.user_id);
2763    let db = session.db().await;
2764    let (contact_accepted, deleted_notification_id) =
2765        db.remove_contact(requester_id, responder_id).await?;
2766
2767    let pool = session.connection_pool().await;
2768    // Update outgoing contact requests of requester
2769    let mut update = proto::UpdateContacts::default();
2770    if contact_accepted {
2771        update.remove_contacts.push(responder_id.to_proto());
2772    } else {
2773        update
2774            .remove_outgoing_requests
2775            .push(responder_id.to_proto());
2776    }
2777    for connection_id in pool.user_connection_ids(requester_id) {
2778        session.peer.send(connection_id, update.clone())?;
2779    }
2780
2781    // Update incoming contact requests of responder
2782    let mut update = proto::UpdateContacts::default();
2783    if contact_accepted {
2784        update.remove_contacts.push(requester_id.to_proto());
2785    } else {
2786        update
2787            .remove_incoming_requests
2788            .push(requester_id.to_proto());
2789    }
2790    for connection_id in pool.user_connection_ids(responder_id) {
2791        session.peer.send(connection_id, update.clone())?;
2792        if let Some(notification_id) = deleted_notification_id {
2793            session.peer.send(
2794                connection_id,
2795                proto::DeleteNotification {
2796                    notification_id: notification_id.to_proto(),
2797                },
2798            )?;
2799        }
2800    }
2801
2802    response.send(proto::Ack {})?;
2803    Ok(())
2804}
2805
2806fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2807    version.0.minor() < 139
2808}
2809
2810async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2811    if is_staff {
2812        return Ok(proto::Plan::ZedPro);
2813    }
2814
2815    let subscription = db.get_active_billing_subscription(user_id).await?;
2816    let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2817
2818    let plan = if let Some(subscription_kind) = subscription_kind {
2819        match subscription_kind {
2820            SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2821            SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2822            SubscriptionKind::ZedFree => proto::Plan::Free,
2823        }
2824    } else {
2825        proto::Plan::Free
2826    };
2827
2828    Ok(plan)
2829}
2830
/// Builds the `UpdateUserPlan` message describing `user`'s plan, trial, billing and usage.
///
/// - `is_staff` makes `current_plan` report Zed Pro and forces usage-based billing to be
///   reported as enabled.
/// - The subscription period and usage are looked up only when `llm_db` is `Some`.
/// - Whenever no usage record is found (including when `llm_db` is `None`), zero usage is
///   reported, with the plan's limits.
/// - `account_too_young` is set for every plan except Zed Pro when the account is younger
///   than `MIN_ACCOUNT_AGE_FOR_LLM_USE`. Users with the bypass feature flag are exempt.
async fn make_update_user_plan_message(
    user: &User,
    is_staff: bool,
    db: &Arc<Database>,
    llm_db: Option<Arc<LlmDatabase>>,
) -> Result<proto::UpdateUserPlan> {
    let feature_flags = db.get_user_flags(user.id).await?;
    let plan = current_plan(db, user.id, is_staff).await?;
    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
    let billing_preferences = db.get_billing_preferences(user.id).await?;

    // Period and usage are only computed when an LLM database is available.
    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
        // NOTE(review): for non-staff users `current_plan` above already fetched the active
        // subscription; this is a second lookup of the same row.
        let subscription = db.get_active_billing_subscription(user.id).await?;

        let subscription_period =
            crate::db::billing_subscription::Model::current_period(subscription, is_staff);

        // No period means there is nothing to measure usage against.
        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
            llm_db
                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
                .await?
        } else {
            None
        };

        (subscription_period, usage)
    } else {
        (None, None)
    };

    let bypass_account_age_check = feature_flags
        .iter()
        .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
    // Only Zed Pro is exempt; trial and free users are age-checked unless the flag is set.
    let account_too_young = !matches!(plan, proto::Plan::ZedPro)
        && !bypass_account_age_check
        && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;

    Ok(proto::UpdateUserPlan {
        plan: plan.into(),
        // Borrowed here because `billing_customer` is consumed by `has_overdue_invoices`.
        trial_started_at: billing_customer
            .as_ref()
            .and_then(|billing_customer| billing_customer.trial_started_at)
            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
        is_usage_based_billing_enabled: if is_staff {
            Some(true)
        } else {
            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
        },
        subscription_period: subscription_period.map(|(started_at, ended_at)| {
            proto::SubscriptionPeriod {
                started_at: started_at.timestamp() as u64,
                ended_at: ended_at.timestamp() as u64,
            }
        }),
        account_too_young: Some(account_too_young),
        has_overdue_invoices: billing_customer
            .map(|billing_customer| billing_customer.has_overdue_invoices),
        // Always populated: missing usage falls back to zero usage with the plan's limits.
        usage: Some(
            usage
                .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
                .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
        ),
    })
}
2895
2896fn model_requests_limit(
2897    plan: cloud_llm_client::Plan,
2898    feature_flags: &Vec<String>,
2899) -> cloud_llm_client::UsageLimit {
2900    match plan.model_requests_limit() {
2901        cloud_llm_client::UsageLimit::Limited(limit) => {
2902            let limit = if plan == cloud_llm_client::Plan::ZedProTrial
2903                && feature_flags
2904                    .iter()
2905                    .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2906            {
2907                1_000
2908            } else {
2909                limit
2910            };
2911
2912            cloud_llm_client::UsageLimit::Limited(limit)
2913        }
2914        cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited,
2915    }
2916}
2917
2918fn subscription_usage_to_proto(
2919    plan: proto::Plan,
2920    usage: crate::llm::db::subscription_usage::Model,
2921    feature_flags: &Vec<String>,
2922) -> proto::SubscriptionUsage {
2923    let plan = match plan {
2924        proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2925        proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
2926        proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
2927    };
2928
2929    proto::SubscriptionUsage {
2930        model_requests_usage_amount: usage.model_requests as u32,
2931        model_requests_usage_limit: Some(proto::UsageLimit {
2932            variant: Some(match model_requests_limit(plan, feature_flags) {
2933                cloud_llm_client::UsageLimit::Limited(limit) => {
2934                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2935                        limit: limit as u32,
2936                    })
2937                }
2938                cloud_llm_client::UsageLimit::Unlimited => {
2939                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2940                }
2941            }),
2942        }),
2943        edit_predictions_usage_amount: usage.edit_predictions as u32,
2944        edit_predictions_usage_limit: Some(proto::UsageLimit {
2945            variant: Some(match plan.edit_predictions_limit() {
2946                cloud_llm_client::UsageLimit::Limited(limit) => {
2947                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2948                        limit: limit as u32,
2949                    })
2950                }
2951                cloud_llm_client::UsageLimit::Unlimited => {
2952                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2953                }
2954            }),
2955        }),
2956    }
2957}
2958
2959fn make_default_subscription_usage(
2960    plan: proto::Plan,
2961    feature_flags: &Vec<String>,
2962) -> proto::SubscriptionUsage {
2963    let plan = match plan {
2964        proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2965        proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
2966        proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
2967    };
2968
2969    proto::SubscriptionUsage {
2970        model_requests_usage_amount: 0,
2971        model_requests_usage_limit: Some(proto::UsageLimit {
2972            variant: Some(match model_requests_limit(plan, feature_flags) {
2973                cloud_llm_client::UsageLimit::Limited(limit) => {
2974                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2975                        limit: limit as u32,
2976                    })
2977                }
2978                cloud_llm_client::UsageLimit::Unlimited => {
2979                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2980                }
2981            }),
2982        }),
2983        edit_predictions_usage_amount: 0,
2984        edit_predictions_usage_limit: Some(proto::UsageLimit {
2985            variant: Some(match plan.edit_predictions_limit() {
2986                cloud_llm_client::UsageLimit::Limited(limit) => {
2987                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2988                        limit: limit as u32,
2989                    })
2990                }
2991                cloud_llm_client::UsageLimit::Unlimited => {
2992                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2993                }
2994            }),
2995        }),
2996    }
2997}
2998
2999async fn update_user_plan(session: &Session) -> Result<()> {
3000    let db = session.db().await;
3001
3002    let update_user_plan = make_update_user_plan_message(
3003        session.principal.user(),
3004        session.is_staff(),
3005        &db.0,
3006        session.app_state.llm_db.clone(),
3007    )
3008    .await?;
3009
3010    session
3011        .peer
3012        .send(session.connection_id, update_user_plan)
3013        .trace_err();
3014
3015    Ok(())
3016}
3017
3018async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3019    subscribe_user_to_channels(session.user_id(), &session).await?;
3020    Ok(())
3021}
3022
3023async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3024    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3025    let mut pool = session.connection_pool().await;
3026    for membership in &channels_for_user.channel_memberships {
3027        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3028    }
3029    session.peer.send(
3030        session.connection_id,
3031        build_update_user_channels(&channels_for_user),
3032    )?;
3033    session.peer.send(
3034        session.connection_id,
3035        build_channels_update(channels_for_user),
3036    )?;
3037    Ok(())
3038}
3039
3040/// Creates a new channel.
3041async fn create_channel(
3042    request: proto::CreateChannel,
3043    response: Response<proto::CreateChannel>,
3044    session: Session,
3045) -> Result<()> {
3046    let db = session.db().await;
3047
3048    let parent_id = request.parent_id.map(ChannelId::from_proto);
3049    let (channel, membership) = db
3050        .create_channel(&request.name, parent_id, session.user_id())
3051        .await?;
3052
3053    let root_id = channel.root_id();
3054    let channel = Channel::from_model(channel);
3055
3056    response.send(proto::CreateChannelResponse {
3057        channel: Some(channel.to_proto()),
3058        parent_id: request.parent_id,
3059    })?;
3060
3061    let mut connection_pool = session.connection_pool().await;
3062    if let Some(membership) = membership {
3063        connection_pool.subscribe_to_channel(
3064            membership.user_id,
3065            membership.channel_id,
3066            membership.role,
3067        );
3068        let update = proto::UpdateUserChannels {
3069            channel_memberships: vec![proto::ChannelMembership {
3070                channel_id: membership.channel_id.to_proto(),
3071                role: membership.role.into(),
3072            }],
3073            ..Default::default()
3074        };
3075        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3076            session.peer.send(connection_id, update.clone())?;
3077        }
3078    }
3079
3080    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3081        if !role.can_see_channel(channel.visibility) {
3082            continue;
3083        }
3084
3085        let update = proto::UpdateChannels {
3086            channels: vec![channel.to_proto()],
3087            ..Default::default()
3088        };
3089        session.peer.send(connection_id, update.clone())?;
3090    }
3091
3092    Ok(())
3093}
3094
3095/// Delete a channel
3096async fn delete_channel(
3097    request: proto::DeleteChannel,
3098    response: Response<proto::DeleteChannel>,
3099    session: Session,
3100) -> Result<()> {
3101    let db = session.db().await;
3102
3103    let channel_id = request.channel_id;
3104    let (root_channel, removed_channels) = db
3105        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3106        .await?;
3107    response.send(proto::Ack {})?;
3108
3109    // Notify members of removed channels
3110    let mut update = proto::UpdateChannels::default();
3111    update
3112        .delete_channels
3113        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3114
3115    let connection_pool = session.connection_pool().await;
3116    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3117        session.peer.send(connection_id, update.clone())?;
3118    }
3119
3120    Ok(())
3121}
3122
3123/// Invite someone to join a channel.
3124async fn invite_channel_member(
3125    request: proto::InviteChannelMember,
3126    response: Response<proto::InviteChannelMember>,
3127    session: Session,
3128) -> Result<()> {
3129    let db = session.db().await;
3130    let channel_id = ChannelId::from_proto(request.channel_id);
3131    let invitee_id = UserId::from_proto(request.user_id);
3132    let InviteMemberResult {
3133        channel,
3134        notifications,
3135    } = db
3136        .invite_channel_member(
3137            channel_id,
3138            invitee_id,
3139            session.user_id(),
3140            request.role().into(),
3141        )
3142        .await?;
3143
3144    let update = proto::UpdateChannels {
3145        channel_invitations: vec![channel.to_proto()],
3146        ..Default::default()
3147    };
3148
3149    let connection_pool = session.connection_pool().await;
3150    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3151        session.peer.send(connection_id, update.clone())?;
3152    }
3153
3154    send_notifications(&connection_pool, &session.peer, notifications);
3155
3156    response.send(proto::Ack {})?;
3157    Ok(())
3158}
3159
3160/// remove someone from a channel
3161async fn remove_channel_member(
3162    request: proto::RemoveChannelMember,
3163    response: Response<proto::RemoveChannelMember>,
3164    session: Session,
3165) -> Result<()> {
3166    let db = session.db().await;
3167    let channel_id = ChannelId::from_proto(request.channel_id);
3168    let member_id = UserId::from_proto(request.user_id);
3169
3170    let RemoveChannelMemberResult {
3171        membership_update,
3172        notification_id,
3173    } = db
3174        .remove_channel_member(channel_id, member_id, session.user_id())
3175        .await?;
3176
3177    let mut connection_pool = session.connection_pool().await;
3178    notify_membership_updated(
3179        &mut connection_pool,
3180        membership_update,
3181        member_id,
3182        &session.peer,
3183    );
3184    for connection_id in connection_pool.user_connection_ids(member_id) {
3185        if let Some(notification_id) = notification_id {
3186            session
3187                .peer
3188                .send(
3189                    connection_id,
3190                    proto::DeleteNotification {
3191                        notification_id: notification_id.to_proto(),
3192                    },
3193                )
3194                .trace_err();
3195        }
3196    }
3197
3198    response.send(proto::Ack {})?;
3199    Ok(())
3200}
3201
/// Toggle the channel between public and private (members-only) on behalf of the current
/// user.
///
/// The channel hierarchy must keep the invariant that public channels only descend from
/// public channels, though members-only channels may appear anywhere.
/// NOTE(review): that invariant is not checked in this handler; presumably
/// `Database::set_channel_visibility` enforces it.
///
/// Afterwards, every user subscribed to the channel's root is re-evaluated:
/// - users whose role can see the new visibility are subscribed to the channel and sent it;
/// - all other users are unsubscribed and told to delete it.
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect the (user, role) pairs up front: the loop body mutates the pool
    // (subscribe/unsubscribe), which it could not do while still iterating over it.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3248
3249/// Alter the role for a user in the channel.
3250async fn set_channel_member_role(
3251    request: proto::SetChannelMemberRole,
3252    response: Response<proto::SetChannelMemberRole>,
3253    session: Session,
3254) -> Result<()> {
3255    let db = session.db().await;
3256    let channel_id = ChannelId::from_proto(request.channel_id);
3257    let member_id = UserId::from_proto(request.user_id);
3258    let result = db
3259        .set_channel_member_role(
3260            channel_id,
3261            session.user_id(),
3262            member_id,
3263            request.role().into(),
3264        )
3265        .await?;
3266
3267    match result {
3268        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3269            let mut connection_pool = session.connection_pool().await;
3270            notify_membership_updated(
3271                &mut connection_pool,
3272                membership_update,
3273                member_id,
3274                &session.peer,
3275            )
3276        }
3277        db::SetMemberRoleResult::InviteUpdated(channel) => {
3278            let update = proto::UpdateChannels {
3279                channel_invitations: vec![channel.to_proto()],
3280                ..Default::default()
3281            };
3282
3283            for connection_id in session
3284                .connection_pool()
3285                .await
3286                .user_connection_ids(member_id)
3287            {
3288                session.peer.send(connection_id, update.clone())?;
3289            }
3290        }
3291    }
3292
3293    response.send(proto::Ack {})?;
3294    Ok(())
3295}
3296
3297/// Change the name of a channel
3298async fn rename_channel(
3299    request: proto::RenameChannel,
3300    response: Response<proto::RenameChannel>,
3301    session: Session,
3302) -> Result<()> {
3303    let db = session.db().await;
3304    let channel_id = ChannelId::from_proto(request.channel_id);
3305    let channel_model = db
3306        .rename_channel(channel_id, session.user_id(), &request.name)
3307        .await?;
3308    let root_id = channel_model.root_id();
3309    let channel = Channel::from_model(channel_model);
3310
3311    response.send(proto::RenameChannelResponse {
3312        channel: Some(channel.to_proto()),
3313    })?;
3314
3315    let connection_pool = session.connection_pool().await;
3316    let update = proto::UpdateChannels {
3317        channels: vec![channel.to_proto()],
3318        ..Default::default()
3319    };
3320    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3321        if role.can_see_channel(channel.visibility) {
3322            session.peer.send(connection_id, update.clone())?;
3323        }
3324    }
3325
3326    Ok(())
3327}
3328
3329/// Move a channel to a new parent.
3330async fn move_channel(
3331    request: proto::MoveChannel,
3332    response: Response<proto::MoveChannel>,
3333    session: Session,
3334) -> Result<()> {
3335    let channel_id = ChannelId::from_proto(request.channel_id);
3336    let to = ChannelId::from_proto(request.to);
3337
3338    let (root_id, channels) = session
3339        .db()
3340        .await
3341        .move_channel(channel_id, to, session.user_id())
3342        .await?;
3343
3344    let connection_pool = session.connection_pool().await;
3345    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3346        let channels = channels
3347            .iter()
3348            .filter_map(|channel| {
3349                if role.can_see_channel(channel.visibility) {
3350                    Some(channel.to_proto())
3351                } else {
3352                    None
3353                }
3354            })
3355            .collect::<Vec<_>>();
3356        if channels.is_empty() {
3357            continue;
3358        }
3359
3360        let update = proto::UpdateChannels {
3361            channels,
3362            ..Default::default()
3363        };
3364
3365        session.peer.send(connection_id, update.clone())?;
3366    }
3367
3368    response.send(Ack {})?;
3369    Ok(())
3370}
3371
3372async fn reorder_channel(
3373    request: proto::ReorderChannel,
3374    response: Response<proto::ReorderChannel>,
3375    session: Session,
3376) -> Result<()> {
3377    let channel_id = ChannelId::from_proto(request.channel_id);
3378    let direction = request.direction();
3379
3380    let updated_channels = session
3381        .db()
3382        .await
3383        .reorder_channel(channel_id, direction, session.user_id())
3384        .await?;
3385
3386    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3387        let connection_pool = session.connection_pool().await;
3388        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3389            let channels = updated_channels
3390                .iter()
3391                .filter_map(|channel| {
3392                    if role.can_see_channel(channel.visibility) {
3393                        Some(channel.to_proto())
3394                    } else {
3395                        None
3396                    }
3397                })
3398                .collect::<Vec<_>>();
3399
3400            if channels.is_empty() {
3401                continue;
3402            }
3403
3404            let update = proto::UpdateChannels {
3405                channels,
3406                ..Default::default()
3407            };
3408
3409            session.peer.send(connection_id, update.clone())?;
3410        }
3411    }
3412
3413    response.send(Ack {})?;
3414    Ok(())
3415}
3416
3417/// Get the list of channel members
3418async fn get_channel_members(
3419    request: proto::GetChannelMembers,
3420    response: Response<proto::GetChannelMembers>,
3421    session: Session,
3422) -> Result<()> {
3423    let db = session.db().await;
3424    let channel_id = ChannelId::from_proto(request.channel_id);
3425    let limit = if request.limit == 0 {
3426        u16::MAX as u64
3427    } else {
3428        request.limit
3429    };
3430    let (members, users) = db
3431        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3432        .await?;
3433    response.send(proto::GetChannelMembersResponse { members, users })?;
3434    Ok(())
3435}
3436
3437/// Accept or decline a channel invitation.
3438async fn respond_to_channel_invite(
3439    request: proto::RespondToChannelInvite,
3440    response: Response<proto::RespondToChannelInvite>,
3441    session: Session,
3442) -> Result<()> {
3443    let db = session.db().await;
3444    let channel_id = ChannelId::from_proto(request.channel_id);
3445    let RespondToChannelInvite {
3446        membership_update,
3447        notifications,
3448    } = db
3449        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3450        .await?;
3451
3452    let mut connection_pool = session.connection_pool().await;
3453    if let Some(membership_update) = membership_update {
3454        notify_membership_updated(
3455            &mut connection_pool,
3456            membership_update,
3457            session.user_id(),
3458            &session.peer,
3459        );
3460    } else {
3461        let update = proto::UpdateChannels {
3462            remove_channel_invitations: vec![channel_id.to_proto()],
3463            ..Default::default()
3464        };
3465
3466        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3467            session.peer.send(connection_id, update.clone())?;
3468        }
3469    };
3470
3471    send_notifications(&connection_pool, &session.peer, notifications);
3472
3473    response.send(proto::Ack {})?;
3474
3475    Ok(())
3476}
3477
3478/// Join the channels' room
3479async fn join_channel(
3480    request: proto::JoinChannel,
3481    response: Response<proto::JoinChannel>,
3482    session: Session,
3483) -> Result<()> {
3484    let channel_id = ChannelId::from_proto(request.channel_id);
3485    join_channel_internal(channel_id, Box::new(response), session).await
3486}
3487
/// Response handle for a request that is answered with a
/// `proto::JoinRoomResponse`.
///
/// Implemented below for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can reply to
/// either kind of request through one code path.
trait JoinChannelInternalResponse {
    /// Consume the handle and send `result` to the requester.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// The fully qualified `Response::<..>::send` call is meant to reach
// `Response`'s own `send` (the one other handlers in this file call directly),
// not this trait's method of the same name. Rewriting it as `self.send(..)`
// would depend on method-resolution priority to avoid recursing.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Same delegation as above, for `JoinRoom` requests.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3501
/// Shared implementation for joining a channel's room.
///
/// Generic over `JoinChannelInternalResponse`, so either a `JoinChannel` or a
/// `JoinRoom` response handle can be answered. Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first. The `db` guard is released while doing so.
/// 2. Join the channel in the database, which yields the joined room, an
///    optional membership update, and the user's channel role.
/// 3. Reply to the requester with the room, its channel id, and LiveKit
///    connection info when a LiveKit client is configured and token creation
///    succeeds.
/// 4. Propagate any membership update and broadcast the room update.
/// 5. Once the inner scope has dropped its guards, broadcast the channel
///    update and refresh the user's contacts.
///
/// Returns an error ("channel not returned") if the joined room carries no
/// channel. Note that this check runs after the response has been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the session's db guard while leaving the stale room
            // (`leave_room_for_session` takes the whole session and may need
            // it), then re-acquire it for the join below.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // LiveKit credentials are only produced when a LiveKit client is
        // configured. Guests get a guest token and `can_publish = false`;
        // every other role gets a room token and `can_publish = true`. If
        // token creation fails, `trace_err()` yields `None` (presumably after
        // logging), the `?` short-circuits the closure, and the join proceeds
        // without LiveKit connection info.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Answer the requester before notifying anyone else about the join.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        // Propagate any membership change the database reported for this join.
        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The inner block has ended, so the `db` guard and the connection-pool
    // guard taken above have been dropped; the pool is locked again here.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3595
3596/// Start editing the channel notes
3597async fn join_channel_buffer(
3598    request: proto::JoinChannelBuffer,
3599    response: Response<proto::JoinChannelBuffer>,
3600    session: Session,
3601) -> Result<()> {
3602    let db = session.db().await;
3603    let channel_id = ChannelId::from_proto(request.channel_id);
3604
3605    let open_response = db
3606        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3607        .await?;
3608
3609    let collaborators = open_response.collaborators.clone();
3610    response.send(open_response)?;
3611
3612    let update = UpdateChannelBufferCollaborators {
3613        channel_id: channel_id.to_proto(),
3614        collaborators: collaborators.clone(),
3615    };
3616    channel_buffer_updated(
3617        session.connection_id,
3618        collaborators
3619            .iter()
3620            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3621        &update,
3622        &session.peer,
3623    );
3624
3625    Ok(())
3626}
3627
3628/// Edit the channel notes
3629async fn update_channel_buffer(
3630    request: proto::UpdateChannelBuffer,
3631    session: Session,
3632) -> Result<()> {
3633    let db = session.db().await;
3634    let channel_id = ChannelId::from_proto(request.channel_id);
3635
3636    let (collaborators, epoch, version) = db
3637        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3638        .await?;
3639
3640    channel_buffer_updated(
3641        session.connection_id,
3642        collaborators.clone(),
3643        &proto::UpdateChannelBuffer {
3644            channel_id: channel_id.to_proto(),
3645            operations: request.operations,
3646        },
3647        &session.peer,
3648    );
3649
3650    let pool = &*session.connection_pool().await;
3651
3652    let non_collaborators =
3653        pool.channel_connection_ids(channel_id)
3654            .filter_map(|(connection_id, _)| {
3655                if collaborators.contains(&connection_id) {
3656                    None
3657                } else {
3658                    Some(connection_id)
3659                }
3660            });
3661
3662    broadcast(None, non_collaborators, |peer_id| {
3663        session.peer.send(
3664            peer_id,
3665            proto::UpdateChannels {
3666                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3667                    channel_id: channel_id.to_proto(),
3668                    epoch: epoch as u64,
3669                    version: version.clone(),
3670                }],
3671                ..Default::default()
3672            },
3673        )
3674    });
3675
3676    Ok(())
3677}
3678
3679/// Rejoin the channel notes after a connection blip
3680async fn rejoin_channel_buffers(
3681    request: proto::RejoinChannelBuffers,
3682    response: Response<proto::RejoinChannelBuffers>,
3683    session: Session,
3684) -> Result<()> {
3685    let db = session.db().await;
3686    let buffers = db
3687        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3688        .await?;
3689
3690    for rejoined_buffer in &buffers {
3691        let collaborators_to_notify = rejoined_buffer
3692            .buffer
3693            .collaborators
3694            .iter()
3695            .filter_map(|c| Some(c.peer_id?.into()));
3696        channel_buffer_updated(
3697            session.connection_id,
3698            collaborators_to_notify,
3699            &proto::UpdateChannelBufferCollaborators {
3700                channel_id: rejoined_buffer.buffer.channel_id,
3701                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3702            },
3703            &session.peer,
3704        );
3705    }
3706
3707    response.send(proto::RejoinChannelBuffersResponse {
3708        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3709    })?;
3710
3711    Ok(())
3712}
3713
3714/// Stop editing the channel notes
3715async fn leave_channel_buffer(
3716    request: proto::LeaveChannelBuffer,
3717    response: Response<proto::LeaveChannelBuffer>,
3718    session: Session,
3719) -> Result<()> {
3720    let db = session.db().await;
3721    let channel_id = ChannelId::from_proto(request.channel_id);
3722
3723    let left_buffer = db
3724        .leave_channel_buffer(channel_id, session.connection_id)
3725        .await?;
3726
3727    response.send(Ack {})?;
3728
3729    channel_buffer_updated(
3730        session.connection_id,
3731        left_buffer.connections,
3732        &proto::UpdateChannelBufferCollaborators {
3733            channel_id: channel_id.to_proto(),
3734            collaborators: left_buffer.collaborators,
3735        },
3736        &session.peer,
3737    );
3738
3739    Ok(())
3740}
3741
3742fn channel_buffer_updated<T: EnvelopedMessage>(
3743    sender_id: ConnectionId,
3744    collaborators: impl IntoIterator<Item = ConnectionId>,
3745    message: &T,
3746    peer: &Peer,
3747) {
3748    broadcast(Some(sender_id), collaborators, |peer_id| {
3749        peer.send(peer_id, message.clone())
3750    });
3751}
3752
3753fn send_notifications(
3754    connection_pool: &ConnectionPool,
3755    peer: &Peer,
3756    notifications: db::NotificationBatch,
3757) {
3758    for (user_id, notification) in notifications {
3759        for connection_id in connection_pool.user_connection_ids(user_id) {
3760            if let Err(error) = peer.send(
3761                connection_id,
3762                proto::AddNotification {
3763                    notification: Some(notification.clone()),
3764                },
3765            ) {
3766                tracing::error!(
3767                    "failed to send notification to {:?} {}",
3768                    connection_id,
3769                    error
3770                );
3771            }
3772        }
3773    }
3774}
3775
/// Send a message to a channel's chat.
///
/// Validation happens before anything is stored. The body is trimmed, must be
/// non-empty, and must not exceed `MAX_MESSAGE_LEN` as measured by `str::len`
/// (bytes). The request must also carry a `nonce`.
///
/// After the message is created, it is broadcast as `ChannelMessageSent` to
/// the participant connections returned by the database. The requester's own
/// connection is passed to `broadcast` as the sender. The requester then gets
/// the stored message back in `SendChannelMessageResponse`. Connections
/// subscribed to the channel that are not participants only receive the new
/// latest message id. Any notifications the database produced are delivered
/// last.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: Session,
) -> Result<()> {
    // Validate the message body.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): `request.mentions` is stored and echoed back unchanged
    // below, but it was produced against the untrimmed `request.body`. If the
    // mention ranges are offsets into the body (assumed; confirm against the
    // client), leading whitespace removed by `trim()` shifts every mention.

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request.nonce.context("nonce can't be blank")?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    // Chat participants receive the full message.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Everyone else subscribed to the channel only learns the new latest
    // message id.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3868
3869/// Delete a channel message
3870async fn remove_channel_message(
3871    request: proto::RemoveChannelMessage,
3872    response: Response<proto::RemoveChannelMessage>,
3873    session: Session,
3874) -> Result<()> {
3875    let channel_id = ChannelId::from_proto(request.channel_id);
3876    let message_id = MessageId::from_proto(request.message_id);
3877    let (connection_ids, existing_notification_ids) = session
3878        .db()
3879        .await
3880        .remove_channel_message(channel_id, message_id, session.user_id())
3881        .await?;
3882
3883    broadcast(
3884        Some(session.connection_id),
3885        connection_ids,
3886        move |connection| {
3887            session.peer.send(connection, request.clone())?;
3888
3889            for notification_id in &existing_notification_ids {
3890                session.peer.send(
3891                    connection,
3892                    proto::DeleteNotification {
3893                        notification_id: (*notification_id).to_proto(),
3894                    },
3895                )?;
3896            }
3897
3898            Ok(())
3899        },
3900    );
3901    response.send(proto::Ack {})?;
3902    Ok(())
3903}
3904
3905async fn update_channel_message(
3906    request: proto::UpdateChannelMessage,
3907    response: Response<proto::UpdateChannelMessage>,
3908    session: Session,
3909) -> Result<()> {
3910    let channel_id = ChannelId::from_proto(request.channel_id);
3911    let message_id = MessageId::from_proto(request.message_id);
3912    let updated_at = OffsetDateTime::now_utc();
3913    let UpdatedChannelMessage {
3914        message_id,
3915        participant_connection_ids,
3916        notifications,
3917        reply_to_message_id,
3918        timestamp,
3919        deleted_mention_notification_ids,
3920        updated_mention_notifications,
3921    } = session
3922        .db()
3923        .await
3924        .update_channel_message(
3925            channel_id,
3926            message_id,
3927            session.user_id(),
3928            request.body.as_str(),
3929            &request.mentions,
3930            updated_at,
3931        )
3932        .await?;
3933
3934    let nonce = request.nonce.clone().context("nonce can't be blank")?;
3935
3936    let message = proto::ChannelMessage {
3937        sender_id: session.user_id().to_proto(),
3938        id: message_id.to_proto(),
3939        body: request.body.clone(),
3940        mentions: request.mentions.clone(),
3941        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3942        nonce: Some(nonce),
3943        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3944        edited_at: Some(updated_at.unix_timestamp() as u64),
3945    };
3946
3947    response.send(proto::Ack {})?;
3948
3949    let pool = &*session.connection_pool().await;
3950    broadcast(
3951        Some(session.connection_id),
3952        participant_connection_ids,
3953        |connection| {
3954            session.peer.send(
3955                connection,
3956                proto::ChannelMessageUpdate {
3957                    channel_id: channel_id.to_proto(),
3958                    message: Some(message.clone()),
3959                },
3960            )?;
3961
3962            for notification_id in &deleted_mention_notification_ids {
3963                session.peer.send(
3964                    connection,
3965                    proto::DeleteNotification {
3966                        notification_id: (*notification_id).to_proto(),
3967                    },
3968                )?;
3969            }
3970
3971            for notification in &updated_mention_notifications {
3972                session.peer.send(
3973                    connection,
3974                    proto::UpdateNotification {
3975                        notification: Some(notification.clone()),
3976                    },
3977                )?;
3978            }
3979
3980            Ok(())
3981        },
3982    );
3983
3984    send_notifications(pool, &session.peer, notifications);
3985
3986    Ok(())
3987}
3988
3989/// Mark a channel message as read
3990async fn acknowledge_channel_message(
3991    request: proto::AckChannelMessage,
3992    session: Session,
3993) -> Result<()> {
3994    let channel_id = ChannelId::from_proto(request.channel_id);
3995    let message_id = MessageId::from_proto(request.message_id);
3996    let notifications = session
3997        .db()
3998        .await
3999        .observe_channel_message(channel_id, session.user_id(), message_id)
4000        .await?;
4001    send_notifications(
4002        &*session.connection_pool().await,
4003        &session.peer,
4004        notifications,
4005    );
4006    Ok(())
4007}
4008
4009/// Mark a buffer version as synced
4010async fn acknowledge_buffer_version(
4011    request: proto::AckBufferOperation,
4012    session: Session,
4013) -> Result<()> {
4014    let buffer_id = BufferId::from_proto(request.buffer_id);
4015    session
4016        .db()
4017        .await
4018        .observe_buffer_version(
4019            buffer_id,
4020            session.user_id(),
4021            request.epoch as i32,
4022            &request.version,
4023        )
4024        .await?;
4025    Ok(())
4026}
4027
4028/// Get a Supermaven API key for the user
4029async fn get_supermaven_api_key(
4030    _request: proto::GetSupermavenApiKey,
4031    response: Response<proto::GetSupermavenApiKey>,
4032    session: Session,
4033) -> Result<()> {
4034    let user_id: String = session.user_id().to_string();
4035    if !session.is_staff() {
4036        return Err(anyhow!("supermaven not enabled for this account"))?;
4037    }
4038
4039    let email = session.email().context("user must have an email")?;
4040
4041    let supermaven_admin_api = session
4042        .supermaven_client
4043        .as_ref()
4044        .context("supermaven not configured")?;
4045
4046    let result = supermaven_admin_api
4047        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4048        .await?;
4049
4050    response.send(proto::GetSupermavenApiKeyResponse {
4051        api_key: result.api_key,
4052    })?;
4053
4054    Ok(())
4055}
4056
4057/// Start receiving chat updates for a channel
4058async fn join_channel_chat(
4059    request: proto::JoinChannelChat,
4060    response: Response<proto::JoinChannelChat>,
4061    session: Session,
4062) -> Result<()> {
4063    let channel_id = ChannelId::from_proto(request.channel_id);
4064
4065    let db = session.db().await;
4066    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4067        .await?;
4068    let messages = db
4069        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4070        .await?;
4071    response.send(proto::JoinChannelChatResponse {
4072        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4073        messages,
4074    })?;
4075    Ok(())
4076}
4077
4078/// Stop receiving chat updates for a channel
4079async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
4080    let channel_id = ChannelId::from_proto(request.channel_id);
4081    session
4082        .db()
4083        .await
4084        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4085        .await?;
4086    Ok(())
4087}
4088
4089/// Retrieve the chat history for a channel
4090async fn get_channel_messages(
4091    request: proto::GetChannelMessages,
4092    response: Response<proto::GetChannelMessages>,
4093    session: Session,
4094) -> Result<()> {
4095    let channel_id = ChannelId::from_proto(request.channel_id);
4096    let messages = session
4097        .db()
4098        .await
4099        .get_channel_messages(
4100            channel_id,
4101            session.user_id(),
4102            MESSAGE_COUNT_PER_PAGE,
4103            Some(MessageId::from_proto(request.before_message_id)),
4104        )
4105        .await?;
4106    response.send(proto::GetChannelMessagesResponse {
4107        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4108        messages,
4109    })?;
4110    Ok(())
4111}
4112
4113/// Retrieve specific chat messages
4114async fn get_channel_messages_by_id(
4115    request: proto::GetChannelMessagesById,
4116    response: Response<proto::GetChannelMessagesById>,
4117    session: Session,
4118) -> Result<()> {
4119    let message_ids = request
4120        .message_ids
4121        .iter()
4122        .map(|id| MessageId::from_proto(*id))
4123        .collect::<Vec<_>>();
4124    let messages = session
4125        .db()
4126        .await
4127        .get_channel_messages_by_id(session.user_id(), &message_ids)
4128        .await?;
4129    response.send(proto::GetChannelMessagesResponse {
4130        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4131        messages,
4132    })?;
4133    Ok(())
4134}
4135
4136/// Retrieve the current users notifications
4137async fn get_notifications(
4138    request: proto::GetNotifications,
4139    response: Response<proto::GetNotifications>,
4140    session: Session,
4141) -> Result<()> {
4142    let notifications = session
4143        .db()
4144        .await
4145        .get_notifications(
4146            session.user_id(),
4147            NOTIFICATION_COUNT_PER_PAGE,
4148            request.before_id.map(db::NotificationId::from_proto),
4149        )
4150        .await?;
4151    response.send(proto::GetNotificationsResponse {
4152        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4153        notifications,
4154    })?;
4155    Ok(())
4156}
4157
4158/// Mark notifications as read
4159async fn mark_notification_as_read(
4160    request: proto::MarkNotificationRead,
4161    response: Response<proto::MarkNotificationRead>,
4162    session: Session,
4163) -> Result<()> {
4164    let database = &session.db().await;
4165    let notifications = database
4166        .mark_notification_as_read_by_id(
4167            session.user_id(),
4168            NotificationId::from_proto(request.notification_id),
4169        )
4170        .await?;
4171    send_notifications(
4172        &*session.connection_pool().await,
4173        &session.peer,
4174        notifications,
4175    );
4176    response.send(proto::Ack {})?;
4177    Ok(())
4178}
4179
4180/// Get the current users information
4181async fn get_private_user_info(
4182    _request: proto::GetPrivateUserInfo,
4183    response: Response<proto::GetPrivateUserInfo>,
4184    session: Session,
4185) -> Result<()> {
4186    let db = session.db().await;
4187
4188    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4189    let user = db
4190        .get_user_by_id(session.user_id())
4191        .await?
4192        .context("user not found")?;
4193    let flags = db.get_user_flags(session.user_id()).await?;
4194
4195    response.send(proto::GetPrivateUserInfoResponse {
4196        metrics_id,
4197        staff: user.admin,
4198        flags,
4199        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4200    })?;
4201    Ok(())
4202}
4203
4204/// Accept the terms of service (tos) on behalf of the current user
4205async fn accept_terms_of_service(
4206    _request: proto::AcceptTermsOfService,
4207    response: Response<proto::AcceptTermsOfService>,
4208    session: Session,
4209) -> Result<()> {
4210    let db = session.db().await;
4211
4212    let accepted_tos_at = Utc::now();
4213    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4214        .await?;
4215
4216    response.send(proto::AcceptTermsOfServiceResponse {
4217        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4218    })?;
4219
4220    // When the user accepts the terms of service, we want to refresh their LLM
4221    // token to grant access.
4222    session
4223        .peer
4224        .send(session.connection_id, proto::RefreshLlmToken {})?;
4225
4226    Ok(())
4227}
4228
/// Issues an LLM token for the calling user, built by `LlmTokenClaims::create`.
///
/// Before building the token, this handler:
/// - requires the user to exist and to have accepted the terms of service;
/// - requires the server to have a Stripe client and a Stripe billing object
///   configured;
/// - ensures a billing customer record exists. When none is stored, it looks up
///   or creates the Stripe customer by the user's email address and then
///   finds or creates the matching billing customer;
/// - ensures an active billing subscription exists. When none is found, it
///   subscribes the customer to Zed Free in Stripe and persists that
///   subscription.
///
/// The user's feature flags, billing preferences, staff status, the session's
/// system id, and the app config are also passed into the token claims.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .with_context(|| format!("user {user_id} not found"))?;

    // No token is issued until the ToS has been accepted. This is written as
    // `Err(..)?` so that `?` converts the `anyhow::Error` into this handler's
    // error type.
    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    let stripe_client = session
        .app_state
        .stripe_client
        .as_ref()
        .context("failed to retrieve Stripe client")?;

    let stripe_billing = session
        .app_state
        .stripe_billing
        .as_ref()
        .context("failed to retrieve Stripe billing object")?;

    // Reuse the stored billing customer. Otherwise resolve the Stripe customer by
    // email and find or create the corresponding billing customer record.
    let billing_customer = if let Some(billing_customer) =
        db.get_billing_customer_by_user_id(user.id).await?
    {
        billing_customer
    } else {
        let customer_id = stripe_billing
            .find_or_create_customer_by_email(user.email_address.as_deref())
            .await?;

        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
            .await?
            .context("billing customer not found")?
    };

    // Reuse the active subscription. Otherwise subscribe the customer to Zed Free
    // in Stripe and record that subscription in the database.
    // NOTE(review): two concurrent GetLlmToken requests that both observe no
    // active subscription could each create a Zed Free subscription. Confirm
    // whether this is prevented elsewhere (e.g. a DB constraint or webhook).
    let billing_subscription =
        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
            billing_subscription
        } else {
            let stripe_customer_id =
                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());

            let stripe_subscription = stripe_billing
                .subscribe_to_zed_free(stripe_customer_id)
                .await?;

            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
                billing_customer_id: billing_customer.id,
                kind: Some(SubscriptionKind::ZedFree),
                stripe_subscription_id: stripe_subscription.id.to_string(),
                stripe_subscription_status: stripe_subscription.status.into(),
                stripe_cancellation_reason: None,
                stripe_current_period_start: Some(stripe_subscription.current_period_start),
                stripe_current_period_end: Some(stripe_subscription.current_period_end),
            })
            .await?
        };

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_customer,
        billing_preferences,
        &flags,
        billing_subscription,
        session.system_id.clone(),
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
4312
4313fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4314    let message = match message {
4315        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4316        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4317        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4318        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4319        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4320            code: frame.code.into(),
4321            reason: frame.reason.as_str().to_owned().into(),
4322        })),
4323        // We should never receive a frame while reading the message, according
4324        // to the `tungstenite` maintainers:
4325        //
4326        // > It cannot occur when you read messages from the WebSocket, but it
4327        // > can be used when you want to send the raw frames (e.g. you want to
4328        // > send the frames to the WebSocket without composing the full message first).
4329        // >
4330        // > — https://github.com/snapview/tungstenite-rs/issues/268
4331        TungsteniteMessage::Frame(_) => {
4332            bail!("received an unexpected frame while reading the message")
4333        }
4334    };
4335
4336    Ok(message)
4337}
4338
4339fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4340    match message {
4341        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4342        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4343        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4344        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4345        AxumMessage::Close(frame) => {
4346            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4347                code: frame.code.into(),
4348                reason: frame.reason.as_ref().into(),
4349            }))
4350        }
4351    }
4352}
4353
4354fn notify_membership_updated(
4355    connection_pool: &mut ConnectionPool,
4356    result: MembershipUpdated,
4357    user_id: UserId,
4358    peer: &Peer,
4359) {
4360    for membership in &result.new_channels.channel_memberships {
4361        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4362    }
4363    for channel_id in &result.removed_channels {
4364        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4365    }
4366
4367    let user_channels_update = proto::UpdateUserChannels {
4368        channel_memberships: result
4369            .new_channels
4370            .channel_memberships
4371            .iter()
4372            .map(|cm| proto::ChannelMembership {
4373                channel_id: cm.channel_id.to_proto(),
4374                role: cm.role.into(),
4375            })
4376            .collect(),
4377        ..Default::default()
4378    };
4379
4380    let mut update = build_channels_update(result.new_channels);
4381    update.delete_channels = result
4382        .removed_channels
4383        .into_iter()
4384        .map(|id| id.to_proto())
4385        .collect();
4386    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4387
4388    for connection_id in connection_pool.user_connection_ids(user_id) {
4389        peer.send(connection_id, user_channels_update.clone())
4390            .trace_err();
4391        peer.send(connection_id, update.clone()).trace_err();
4392    }
4393}
4394
4395fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4396    proto::UpdateUserChannels {
4397        channel_memberships: channels
4398            .channel_memberships
4399            .iter()
4400            .map(|m| proto::ChannelMembership {
4401                channel_id: m.channel_id.to_proto(),
4402                role: m.role.into(),
4403            })
4404            .collect(),
4405        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4406        observed_channel_message_id: channels.observed_channel_messages.clone(),
4407    }
4408}
4409
4410fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4411    let mut update = proto::UpdateChannels::default();
4412
4413    for channel in channels.channels {
4414        update.channels.push(channel.to_proto());
4415    }
4416
4417    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4418    update.latest_channel_message_ids = channels.latest_channel_messages;
4419
4420    for (channel_id, participants) in channels.channel_participants {
4421        update
4422            .channel_participants
4423            .push(proto::ChannelParticipants {
4424                channel_id: channel_id.to_proto(),
4425                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4426            });
4427    }
4428
4429    for channel in channels.invited_channels {
4430        update.channel_invitations.push(channel.to_proto());
4431    }
4432
4433    update
4434}
4435
4436fn build_initial_contacts_update(
4437    contacts: Vec<db::Contact>,
4438    pool: &ConnectionPool,
4439) -> proto::UpdateContacts {
4440    let mut update = proto::UpdateContacts::default();
4441
4442    for contact in contacts {
4443        match contact {
4444            db::Contact::Accepted { user_id, busy } => {
4445                update.contacts.push(contact_for_user(user_id, busy, pool));
4446            }
4447            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4448            db::Contact::Incoming { user_id } => {
4449                update
4450                    .incoming_requests
4451                    .push(proto::IncomingContactRequest {
4452                        requester_id: user_id.to_proto(),
4453                    })
4454            }
4455        }
4456    }
4457
4458    update
4459}
4460
4461fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4462    proto::Contact {
4463        user_id: user_id.to_proto(),
4464        online: pool.is_user_online(user_id),
4465        busy,
4466    }
4467}
4468
4469fn room_updated(room: &proto::Room, peer: &Peer) {
4470    broadcast(
4471        None,
4472        room.participants
4473            .iter()
4474            .filter_map(|participant| Some(participant.peer_id?.into())),
4475        |peer_id| {
4476            peer.send(
4477                peer_id,
4478                proto::RoomUpdated {
4479                    room: Some(room.clone()),
4480                },
4481            )
4482        },
4483    );
4484}
4485
4486fn channel_updated(
4487    channel: &db::channel::Model,
4488    room: &proto::Room,
4489    peer: &Peer,
4490    pool: &ConnectionPool,
4491) {
4492    let participants = room
4493        .participants
4494        .iter()
4495        .map(|p| p.user_id)
4496        .collect::<Vec<_>>();
4497
4498    broadcast(
4499        None,
4500        pool.channel_connection_ids(channel.root_id())
4501            .filter_map(|(channel_id, role)| {
4502                role.can_see_channel(channel.visibility)
4503                    .then_some(channel_id)
4504            }),
4505        |peer_id| {
4506            peer.send(
4507                peer_id,
4508                proto::UpdateChannels {
4509                    channel_participants: vec![proto::ChannelParticipants {
4510                        channel_id: channel.id.to_proto(),
4511                        participant_user_ids: participants.clone(),
4512                    }],
4513                    ..Default::default()
4514                },
4515            )
4516        },
4517    );
4518}
4519
4520async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4521    let db = session.db().await;
4522
4523    let contacts = db.get_contacts(user_id).await?;
4524    let busy = db.is_user_busy(user_id).await?;
4525
4526    let pool = session.connection_pool().await;
4527    let updated_contact = contact_for_user(user_id, busy, &pool);
4528    for contact in contacts {
4529        if let db::Contact::Accepted {
4530            user_id: contact_user_id,
4531            ..
4532        } = contact
4533        {
4534            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4535                session
4536                    .peer
4537                    .send(
4538                        contact_conn_id,
4539                        proto::UpdateContacts {
4540                            contacts: vec![updated_contact.clone()],
4541                            remove_contacts: Default::default(),
4542                            incoming_requests: Default::default(),
4543                            remove_incoming_requests: Default::default(),
4544                            outgoing_requests: Default::default(),
4545                            remove_outgoing_requests: Default::default(),
4546                        },
4547                    )
4548                    .trace_err();
4549            }
4550        }
4551    }
4552    Ok(())
4553}
4554
/// Removes `connection_id` from the room it is in and propagates the effects.
///
/// After removing the connection, this function:
/// - notifies the connections of each project the session left;
/// - broadcasts the updated room, and the channel's participant list if the
///   room belongs to a channel;
/// - sends `CallCanceled` to users whose pending calls were canceled;
/// - refreshes the contact entries of the leaving user and of those canceled
///   callees;
/// - removes the user from the LiveKit room (if a LiveKit client is
///   configured), and deletes that LiveKit room when the database reports the
///   room as deleted.
///
/// Returns `Ok(())` without doing anything if the database reports that the
/// connection was not in a room.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front and assigned inside the `if let` below, so they outlive
    // the `left_room` value that is only in scope there.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // NOTE(review): `session.db().await` is a temporary in the `if let`
    // scrutinee. If it is a lock guard, it stays alive for this whole `then`
    // block, including `project_left` and `room_updated`.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `livekit_room` is taken out of `left_room.room` before the
        // room is moved into `room`. The `RoomUpdated` broadcast below therefore
        // carries the field's default value. Confirm this is intended.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection pool guard is released before the contact
    // updates below, which acquire it again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged, not propagated.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4628
4629async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4630    let left_channel_buffers = session
4631        .db()
4632        .await
4633        .leave_channel_buffers(session.connection_id)
4634        .await?;
4635
4636    for left_buffer in left_channel_buffers {
4637        channel_buffer_updated(
4638            session.connection_id,
4639            left_buffer.connections,
4640            &proto::UpdateChannelBufferCollaborators {
4641                channel_id: left_buffer.channel_id.to_proto(),
4642                collaborators: left_buffer.collaborators,
4643            },
4644            &session.peer,
4645        );
4646    }
4647
4648    Ok(())
4649}
4650
4651fn project_left(project: &db::LeftProject, session: &Session) {
4652    for connection_id in &project.connection_ids {
4653        if project.should_unshare {
4654            session
4655                .peer
4656                .send(
4657                    *connection_id,
4658                    proto::UnshareProject {
4659                        project_id: project.id.to_proto(),
4660                    },
4661                )
4662                .trace_err();
4663        } else {
4664            session
4665                .peer
4666                .send(
4667                    *connection_id,
4668                    proto::RemoveProjectCollaborator {
4669                        project_id: project.id.to_proto(),
4670                        peer_id: Some(session.connection_id.into()),
4671                    },
4672                )
4673                .trace_err();
4674        }
4675    }
4676}
4677
4678pub trait ResultExt {
4679    type Ok;
4680
4681    fn trace_err(self) -> Option<Self::Ok>;
4682}
4683
4684impl<T, E> ResultExt for Result<T, E>
4685where
4686    E: std::fmt::Debug,
4687{
4688    type Ok = T;
4689
4690    #[track_caller]
4691    fn trace_err(self) -> Option<T> {
4692        match self {
4693            Ok(value) => Some(value),
4694            Err(error) => {
4695                tracing::error!("{:?}", error);
4696                None
4697            }
4698        }
4699    }
4700}