rpc.rs

   1mod connection_pool;
   2
   3use crate::api::billing::find_or_create_billing_customer;
   4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   5use crate::db::billing_subscription::SubscriptionKind;
   6use crate::llm::db::LlmDatabase;
   7use crate::llm::{
   8    AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
   9    MIN_ACCOUNT_AGE_FOR_LLM_USE,
  10};
  11use crate::stripe_client::StripeCustomerId;
  12use crate::{
  13    AppState, Error, Result, auth,
  14    db::{
  15        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
  16        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  17        NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
  18        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  19    },
  20    executor::Executor,
  21};
  22use anyhow::{Context as _, anyhow, bail};
  23use async_tungstenite::tungstenite::{
  24    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  25};
  26use axum::headers::UserAgent;
  27use axum::{
  28    Extension, Router, TypedHeader,
  29    body::Body,
  30    extract::{
  31        ConnectInfo, WebSocketUpgrade,
  32        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  33    },
  34    headers::{Header, HeaderName},
  35    http::StatusCode,
  36    middleware,
  37    response::IntoResponse,
  38    routing::get,
  39};
  40use chrono::Utc;
  41use collections::{HashMap, HashSet};
  42pub use connection_pool::{ConnectionPool, ZedVersion};
  43use core::fmt::{self, Debug, Formatter};
  44use reqwest_client::ReqwestClient;
  45use rpc::proto::{MultiLspQuery, split_repository_update};
  46use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  47
  48use futures::{
  49    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  50    stream::FuturesUnordered,
  51};
  52use prometheus::{IntGauge, register_int_gauge};
  53use rpc::{
  54    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  55    proto::{
  56        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  57        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  58    },
  59};
  60use semantic_version::SemanticVersion;
  61use serde::{Serialize, Serializer};
  62use std::{
  63    any::TypeId,
  64    future::Future,
  65    marker::PhantomData,
  66    mem,
  67    net::SocketAddr,
  68    ops::{Deref, DerefMut},
  69    rc::Rc,
  70    sync::{
  71        Arc, OnceLock,
  72        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  73    },
  74    time::{Duration, Instant},
  75};
  76use time::OffsetDateTime;
  77use tokio::sync::{Semaphore, watch};
  78use tower::ServiceBuilder;
  79use tracing::{
  80    Instrument,
  81    field::{self},
  82    info_span, instrument,
  83};
  84
  85pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  86
  87// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  88pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  89
  90const MESSAGE_COUNT_PER_PAGE: usize = 100;
  91const MAX_MESSAGE_LEN: usize = 1024;
  92const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  93const MAX_CONCURRENT_CONNECTIONS: usize = 512;
  94
  95static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
  96
  97type MessageHandler =
  98    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  99
 100pub struct ConnectionGuard;
 101
 102impl ConnectionGuard {
 103    pub fn try_acquire() -> Result<Self, ()> {
 104        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
 105        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 106            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 107            tracing::error!(
 108                "too many concurrent connections: {}",
 109                current_connections + 1
 110            );
 111            return Err(());
 112        }
 113        Ok(ConnectionGuard)
 114    }
 115}
 116
 117impl Drop for ConnectionGuard {
 118    fn drop(&mut self) {
 119        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 120    }
 121}
 122
 123struct Response<R> {
 124    peer: Arc<Peer>,
 125    receipt: Receipt<R>,
 126    responded: Arc<AtomicBool>,
 127}
 128
 129impl<R: RequestMessage> Response<R> {
 130    fn send(self, payload: R::Response) -> Result<()> {
 131        self.responded.store(true, SeqCst);
 132        self.peer.respond(self.receipt, payload)?;
 133        Ok(())
 134    }
 135}
 136
 137#[derive(Clone, Debug)]
 138pub enum Principal {
 139    User(User),
 140    Impersonated { user: User, admin: User },
 141}
 142
 143impl Principal {
 144    fn user(&self) -> &User {
 145        match self {
 146            Principal::User(user) => user,
 147            Principal::Impersonated { user, .. } => user,
 148        }
 149    }
 150
 151    fn update_span(&self, span: &tracing::Span) {
 152        match &self {
 153            Principal::User(user) => {
 154                span.record("user_id", user.id.0);
 155                span.record("login", &user.github_login);
 156            }
 157            Principal::Impersonated { user, admin } => {
 158                span.record("user_id", user.id.0);
 159                span.record("login", &user.github_login);
 160                span.record("impersonator", &admin.github_login);
 161            }
 162        }
 163    }
 164}
 165
 166#[derive(Clone)]
 167struct Session {
 168    principal: Principal,
 169    connection_id: ConnectionId,
 170    db: Arc<tokio::sync::Mutex<DbHandle>>,
 171    peer: Arc<Peer>,
 172    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 173    app_state: Arc<AppState>,
 174    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 175    /// The GeoIP country code for the user.
 176    #[allow(unused)]
 177    geoip_country_code: Option<String>,
 178    system_id: Option<String>,
 179    _executor: Executor,
 180}
 181
 182impl Session {
 183    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 184        #[cfg(test)]
 185        tokio::task::yield_now().await;
 186        let guard = self.db.lock().await;
 187        #[cfg(test)]
 188        tokio::task::yield_now().await;
 189        guard
 190    }
 191
 192    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 193        #[cfg(test)]
 194        tokio::task::yield_now().await;
 195        let guard = self.connection_pool.lock();
 196        ConnectionPoolGuard {
 197            guard,
 198            _not_send: PhantomData,
 199        }
 200    }
 201
 202    fn is_staff(&self) -> bool {
 203        match &self.principal {
 204            Principal::User(user) => user.admin,
 205            Principal::Impersonated { .. } => true,
 206        }
 207    }
 208
 209    fn user_id(&self) -> UserId {
 210        match &self.principal {
 211            Principal::User(user) => user.id,
 212            Principal::Impersonated { user, .. } => user.id,
 213        }
 214    }
 215
 216    pub fn email(&self) -> Option<String> {
 217        match &self.principal {
 218            Principal::User(user) => user.email_address.clone(),
 219            Principal::Impersonated { user, .. } => user.email_address.clone(),
 220        }
 221    }
 222}
 223
 224impl Debug for Session {
 225    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 226        let mut result = f.debug_struct("Session");
 227        match &self.principal {
 228            Principal::User(user) => {
 229                result.field("user", &user.github_login);
 230            }
 231            Principal::Impersonated { user, admin } => {
 232                result.field("user", &user.github_login);
 233                result.field("impersonator", &admin.github_login);
 234            }
 235        }
 236        result.field("connection_id", &self.connection_id).finish()
 237    }
 238}
 239
 240struct DbHandle(Arc<Database>);
 241
 242impl Deref for DbHandle {
 243    type Target = Database;
 244
 245    fn deref(&self) -> &Self::Target {
 246        self.0.as_ref()
 247    }
 248}
 249
/// The RPC server: owns the peer, the connection pool, and the table of
/// per-message-type handlers.
pub struct Server {
    /// This server instance's id. Behind a mutex so the test-only `reset` can
    /// replace it through `&self`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag: set to `true` by `teardown` and back to `false` by the
    /// test-only `reset`.
    teardown: watch::Sender<bool>,
}
 258
 259pub(crate) struct ConnectionPoolGuard<'a> {
 260    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 261    _not_send: PhantomData<Rc<()>>,
 262}
 263
 264#[derive(Serialize)]
 265pub struct ServerSnapshot<'a> {
 266    peer: &'a Peer,
 267    #[serde(serialize_with = "serialize_deref")]
 268    connection_pool: ConnectionPoolGuard<'a>,
 269}
 270
 271pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 272where
 273    S: Serializer,
 274    T: Deref<Target = U>,
 275    U: Serialize,
 276{
 277    Serialize::serialize(value.deref(), serializer)
 278}
 279
 280impl Server {
 281    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 282        let mut server = Self {
 283            id: parking_lot::Mutex::new(id),
 284            peer: Peer::new(id.0 as u32),
 285            app_state: app_state.clone(),
 286            connection_pool: Default::default(),
 287            handlers: Default::default(),
 288            teardown: watch::channel(false).0,
 289        };
 290
 291        server
 292            .add_request_handler(ping)
 293            .add_request_handler(create_room)
 294            .add_request_handler(join_room)
 295            .add_request_handler(rejoin_room)
 296            .add_request_handler(leave_room)
 297            .add_request_handler(set_room_participant_role)
 298            .add_request_handler(call)
 299            .add_request_handler(cancel_call)
 300            .add_message_handler(decline_call)
 301            .add_request_handler(update_participant_location)
 302            .add_request_handler(share_project)
 303            .add_message_handler(unshare_project)
 304            .add_request_handler(join_project)
 305            .add_message_handler(leave_project)
 306            .add_request_handler(update_project)
 307            .add_request_handler(update_worktree)
 308            .add_request_handler(update_repository)
 309            .add_request_handler(remove_repository)
 310            .add_message_handler(start_language_server)
 311            .add_message_handler(update_language_server)
 312            .add_message_handler(update_diagnostic_summary)
 313            .add_message_handler(update_worktree_settings)
 314            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 315            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 316            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 317            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 318            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 319            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 320            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 321            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 322            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 323            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 324            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 325            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 326            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 327            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 328            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 329            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 330            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 331            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 332            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 333            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 334            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 335            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 336            .add_request_handler(
 337                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 338            )
 339            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 340            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 341            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 342            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 343            .add_request_handler(
 344                forward_read_only_project_request::<proto::LanguageServerIdForName>,
 345            )
 346            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
 347            .add_request_handler(
 348                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 349            )
 350            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 351            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 352            .add_request_handler(
 353                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 354            )
 355            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 356            .add_request_handler(
 357                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 358            )
 359            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 360            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 361            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 362            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 363            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 364            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 365            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 366            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 367            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 368            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 369            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 370            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 371            .add_request_handler(
 372                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 373            )
 374            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 375            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 376            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 377            .add_request_handler(multi_lsp_query)
 378            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 379            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 380            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 381            .add_message_handler(create_buffer_for_peer)
 382            .add_request_handler(update_buffer)
 383            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 384            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 385            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 386            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 387            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 388            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 389            .add_message_handler(
 390                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 391            )
 392            .add_request_handler(get_users)
 393            .add_request_handler(fuzzy_search_users)
 394            .add_request_handler(request_contact)
 395            .add_request_handler(remove_contact)
 396            .add_request_handler(respond_to_contact_request)
 397            .add_message_handler(subscribe_to_channels)
 398            .add_request_handler(create_channel)
 399            .add_request_handler(delete_channel)
 400            .add_request_handler(invite_channel_member)
 401            .add_request_handler(remove_channel_member)
 402            .add_request_handler(set_channel_member_role)
 403            .add_request_handler(set_channel_visibility)
 404            .add_request_handler(rename_channel)
 405            .add_request_handler(join_channel_buffer)
 406            .add_request_handler(leave_channel_buffer)
 407            .add_message_handler(update_channel_buffer)
 408            .add_request_handler(rejoin_channel_buffers)
 409            .add_request_handler(get_channel_members)
 410            .add_request_handler(respond_to_channel_invite)
 411            .add_request_handler(join_channel)
 412            .add_request_handler(join_channel_chat)
 413            .add_message_handler(leave_channel_chat)
 414            .add_request_handler(send_channel_message)
 415            .add_request_handler(remove_channel_message)
 416            .add_request_handler(update_channel_message)
 417            .add_request_handler(get_channel_messages)
 418            .add_request_handler(get_channel_messages_by_id)
 419            .add_request_handler(get_notifications)
 420            .add_request_handler(mark_notification_as_read)
 421            .add_request_handler(move_channel)
 422            .add_request_handler(reorder_channel)
 423            .add_request_handler(follow)
 424            .add_message_handler(unfollow)
 425            .add_message_handler(update_followers)
 426            .add_request_handler(get_private_user_info)
 427            .add_request_handler(get_llm_api_token)
 428            .add_request_handler(accept_terms_of_service)
 429            .add_message_handler(acknowledge_channel_message)
 430            .add_message_handler(acknowledge_buffer_version)
 431            .add_request_handler(get_supermaven_api_key)
 432            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 433            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 434            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 435            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 436            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 437            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 438            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 439            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 440            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 441            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 442            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 443            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 444            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 445            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 446            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 447            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 448            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 449            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 450            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 451            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 452            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 453            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 454            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 455            .add_message_handler(update_context);
 456
 457        Arc::new(server)
 458    }
 459
    /// Spawns a detached background cleanup task and returns `Ok(())` right away.
    ///
    /// After `CLEANUP_TIMEOUT` elapses, the task reclaims resources that the
    /// database reports as stale relative to this server's id:
    /// 1. deletes stale channel-chat participants;
    /// 2. for each stale channel buffer, clears stale collaborators and sends
    ///    the refreshed collaborator list to the remaining connections;
    /// 3. for each stale room, clears stale participants, broadcasts the room
    ///    (and channel, if any) update, notifies callees of canceled calls,
    ///    pushes updated contact info to affected users' accepted contacts, and
    ///    deletes the LiveKit room if no participants remain;
    /// 4. repeats the chat-participant cleanup, clears old worktree entries,
    ///    and deletes stale servers.
    ///
    /// Every step is best-effort: failures go through `trace_err()` and do not
    /// stop the remaining steps.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created here, awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and tell the
                    // remaining connections who is still collaborating.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults describe "nothing to do" if clearing the room fails.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // Only tear down the LiveKit room once nobody is left in it.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to
                        // all connections of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // NOTE(review): repeats the call made right after the timeout;
                // presumably harmless if idempotent — confirm both are needed.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 630
 631    pub fn teardown(&self) {
 632        self.peer.teardown();
 633        self.connection_pool.lock().reset();
 634        let _ = self.teardown.send(true);
 635    }
 636
 637    #[cfg(test)]
 638    pub fn reset(&self, id: ServerId) {
 639        self.teardown();
 640        *self.id.lock() = id;
 641        self.peer.reset(id.0 as u32);
 642        let _ = self.teardown.send(false);
 643    }
 644
 645    #[cfg(test)]
 646    pub fn id(&self) -> ServerId {
 647        *self.id.lock()
 648    }
 649
 650    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 651    where
 652        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 653        Fut: 'static + Send + Future<Output = Result<()>>,
 654        M: EnvelopedMessage,
 655    {
 656        let prev_handler = self.handlers.insert(
 657            TypeId::of::<M>(),
 658            Box::new(move |envelope, session| {
 659                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 660                let received_at = envelope.received_at;
 661                tracing::info!("message received");
 662                let start_time = Instant::now();
 663                let future = (handler)(*envelope, session);
 664                async move {
 665                    let result = future.await;
 666                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 667                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 668                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 669
 670                    match result {
 671                        Err(error) => {
 672                            tracing::error!(
 673                                ?error,
 674                                total_duration_ms,
 675                                processing_duration_ms,
 676                                queue_duration_ms,
 677                                "error handling message"
 678                            )
 679                        }
 680                        Ok(()) => tracing::info!(
 681                            total_duration_ms,
 682                            processing_duration_ms,
 683                            queue_duration_ms,
 684                            "finished handling message"
 685                        ),
 686                    }
 687                }
 688                .boxed()
 689            }),
 690        );
 691        if prev_handler.is_some() {
 692            panic!("registered a handler for the same message twice");
 693        }
 694        self
 695    }
 696
 697    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 698    where
 699        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 700        Fut: 'static + Send + Future<Output = Result<()>>,
 701        M: EnvelopedMessage,
 702    {
 703        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 704        self
 705    }
 706
 707    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 708    where
 709        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 710        Fut: Send + Future<Output = Result<()>>,
 711        M: RequestMessage,
 712    {
 713        let handler = Arc::new(handler);
 714        self.add_handler(move |envelope, session| {
 715            let receipt = envelope.receipt();
 716            let handler = handler.clone();
 717            async move {
 718                let peer = session.peer.clone();
 719                let responded = Arc::new(AtomicBool::default());
 720                let response = Response {
 721                    peer: peer.clone(),
 722                    responded: responded.clone(),
 723                    receipt,
 724                };
 725                match (handler)(envelope.payload, response, session).await {
 726                    Ok(()) => {
 727                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 728                            Ok(())
 729                        } else {
 730                            Err(anyhow!("handler did not send a response"))?
 731                        }
 732                    }
 733                    Err(error) => {
 734                        let proto_err = match &error {
 735                            Error::Internal(err) => err.to_proto(),
 736                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 737                        };
 738                        peer.respond_with_error(receipt, proto_err)?;
 739                        Err(error)
 740                    }
 741                }
 742            }
 743        })
 744    }
 745
    /// Returns a future that drives a single client connection from handshake to sign-out.
    ///
    /// The future:
    /// 1. registers `connection` with the peer;
    /// 2. builds the `Session`;
    /// 3. sends the initial client update;
    /// 4. dispatches incoming messages to the registered handlers until the I/O future finishes
    ///    or the incoming stream ends;
    /// 5. runs `connection_lost`.
    ///
    /// If server teardown is observed, either before starting or while in the message loop, it
    /// returns immediately without running `connection_lost`.
    ///
    /// - `address`, `user_agent`, `geoip_country_code`: recorded on the tracing span; the country
    ///   code and `system_id` are also stored on the session.
    /// - `send_connection_id`: forwarded to `send_initial_client_update`, which sends the
    ///   connection id on it after the `Hello` message.
    /// - `connection_guard`: dropped once the initial client update has been sent successfully
    ///   (or when the future ends early).
    ///
    /// `+ use<>` means the returned future captures no lifetimes; it owns a clone of `self`.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields left `Empty` here are filled in below or by `principal.update_span`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribe before the future starts so a teardown issued after this point is observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            // The peer gets a sleep function backed by this connection's executor.
            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // NOTE(review): this early return (and the one after `send_initial_client_update`
            // below) happens after `add_connection` but skips `connection_lost`, so its
            // `peer.disconnect` / pool removal does not run on these paths — confirm intended.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only built when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The handshake succeeded; release the guard taken before the WebSocket upgrade.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Each dispatched handler holds one permit until it completes, which caps in-flight
            // handlers (foreground and background) at MAX_CONCURRENT_HANDLERS.
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // A permit is acquired *before* reading the next message, so no new message is
                // pulled while all permits are in use.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // Biased: branches are polled in order, so teardown takes priority, then the I/O
                // future, then progress on foreground handlers, then new messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                multi_lsp_query_request=field::Empty,
                            );
                            principal.update_span(&span);
                            // Enter the span only while the handler future is created; the
                            // future itself is instrumented with `span` below.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is released only after the handler finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached from this loop; foreground
                                // ones are polled by the select above.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that are still pending.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
 927
 928    async fn send_initial_client_update(
 929        &self,
 930        connection_id: ConnectionId,
 931        zed_version: ZedVersion,
 932        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 933        session: &Session,
 934    ) -> Result<()> {
 935        self.peer.send(
 936            connection_id,
 937            proto::Hello {
 938                peer_id: Some(connection_id.into()),
 939            },
 940        )?;
 941        tracing::info!("sent hello message");
 942        if let Some(send_connection_id) = send_connection_id.take() {
 943            let _ = send_connection_id.send(connection_id);
 944        }
 945
 946        match &session.principal {
 947            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 948                if !user.connected_once {
 949                    self.peer.send(connection_id, proto::ShowContacts {})?;
 950                    self.app_state
 951                        .db
 952                        .set_user_connected_once(user.id, true)
 953                        .await?;
 954                }
 955
 956                update_user_plan(session).await?;
 957
 958                let contacts = self.app_state.db.get_contacts(user.id).await?;
 959
 960                {
 961                    let mut pool = self.connection_pool.lock();
 962                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 963                    self.peer.send(
 964                        connection_id,
 965                        build_initial_contacts_update(contacts, &pool),
 966                    )?;
 967                }
 968
 969                if should_auto_subscribe_to_channels(zed_version) {
 970                    subscribe_user_to_channels(user.id, session).await?;
 971                }
 972
 973                if let Some(incoming_call) =
 974                    self.app_state.db.incoming_call_for_user(user.id).await?
 975                {
 976                    self.peer.send(connection_id, incoming_call)?;
 977                }
 978
 979                update_user_contacts(user.id, session).await?;
 980            }
 981        }
 982
 983        Ok(())
 984    }
 985
 986    pub async fn invite_code_redeemed(
 987        self: &Arc<Self>,
 988        inviter_id: UserId,
 989        invitee_id: UserId,
 990    ) -> Result<()> {
 991        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 992            if let Some(code) = &user.invite_code {
 993                let pool = self.connection_pool.lock();
 994                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 995                for connection_id in pool.user_connection_ids(inviter_id) {
 996                    self.peer.send(
 997                        connection_id,
 998                        proto::UpdateContacts {
 999                            contacts: vec![invitee_contact.clone()],
1000                            ..Default::default()
1001                        },
1002                    )?;
1003                    self.peer.send(
1004                        connection_id,
1005                        proto::UpdateInviteInfo {
1006                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1007                            count: user.invite_count as u32,
1008                        },
1009                    )?;
1010                }
1011            }
1012        }
1013        Ok(())
1014    }
1015
1016    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1017        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1018            if let Some(invite_code) = &user.invite_code {
1019                let pool = self.connection_pool.lock();
1020                for connection_id in pool.user_connection_ids(user_id) {
1021                    self.peer.send(
1022                        connection_id,
1023                        proto::UpdateInviteInfo {
1024                            url: format!(
1025                                "{}{}",
1026                                self.app_state.config.invite_link_prefix, invite_code
1027                            ),
1028                            count: user.invite_count as u32,
1029                        },
1030                    )?;
1031                }
1032            }
1033        }
1034        Ok(())
1035    }
1036
1037    pub async fn update_plan_for_user(
1038        self: &Arc<Self>,
1039        user_id: UserId,
1040        update_user_plan: proto::UpdateUserPlan,
1041    ) -> Result<()> {
1042        let pool = self.connection_pool.lock();
1043        for connection_id in pool.user_connection_ids(user_id) {
1044            self.peer
1045                .send(connection_id, update_user_plan.clone())
1046                .trace_err();
1047        }
1048
1049        Ok(())
1050    }
1051
1052    /// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan`
1053    /// message on the Collab server.
1054    ///
1055    /// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint.
1056    pub async fn update_plan_for_user_legacy(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1057        let user = self
1058            .app_state
1059            .db
1060            .get_user_by_id(user_id)
1061            .await?
1062            .context("user not found")?;
1063
1064        let update_user_plan = make_update_user_plan_message(
1065            &user,
1066            user.admin,
1067            &self.app_state.db,
1068            self.app_state.llm_db.clone(),
1069        )
1070        .await?;
1071
1072        self.update_plan_for_user(user_id, update_user_plan).await
1073    }
1074
1075    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
1076        let pool = self.connection_pool.lock();
1077        for connection_id in pool.user_connection_ids(user_id) {
1078            self.peer
1079                .send(connection_id, proto::RefreshLlmToken {})
1080                .trace_err();
1081        }
1082    }
1083
1084    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1085        ServerSnapshot {
1086            connection_pool: ConnectionPoolGuard {
1087                guard: self.connection_pool.lock(),
1088                _not_send: PhantomData,
1089            },
1090            peer: &self.peer,
1091        }
1092    }
1093}
1094
1095impl Deref for ConnectionPoolGuard<'_> {
1096    type Target = ConnectionPool;
1097
1098    fn deref(&self) -> &Self::Target {
1099        &self.guard
1100    }
1101}
1102
1103impl DerefMut for ConnectionPoolGuard<'_> {
1104    fn deref_mut(&mut self) -> &mut Self::Target {
1105        &mut self.guard
1106    }
1107}
1108
1109impl Drop for ConnectionPoolGuard<'_> {
1110    fn drop(&mut self) {
1111        #[cfg(test)]
1112        self.check_invariants();
1113    }
1114}
1115
1116fn broadcast<F>(
1117    sender_id: Option<ConnectionId>,
1118    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1119    mut f: F,
1120) where
1121    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1122{
1123    for receiver_id in receiver_ids {
1124        if Some(receiver_id) != sender_id {
1125            if let Err(error) = f(receiver_id) {
1126                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1127            }
1128        }
1129    }
1130}
1131
1132pub struct ProtocolVersion(u32);
1133
1134impl Header for ProtocolVersion {
1135    fn name() -> &'static HeaderName {
1136        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1137        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1138    }
1139
1140    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1141    where
1142        Self: Sized,
1143        I: Iterator<Item = &'i axum::http::HeaderValue>,
1144    {
1145        let version = values
1146            .next()
1147            .ok_or_else(axum::headers::Error::invalid)?
1148            .to_str()
1149            .map_err(|_| axum::headers::Error::invalid())?
1150            .parse()
1151            .map_err(|_| axum::headers::Error::invalid())?;
1152        Ok(Self(version))
1153    }
1154
1155    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1156        values.extend([self.0.to_string().parse().unwrap()]);
1157    }
1158}
1159
1160pub struct AppVersionHeader(SemanticVersion);
1161impl Header for AppVersionHeader {
1162    fn name() -> &'static HeaderName {
1163        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1164        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1165    }
1166
1167    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1168    where
1169        Self: Sized,
1170        I: Iterator<Item = &'i axum::http::HeaderValue>,
1171    {
1172        let version = values
1173            .next()
1174            .ok_or_else(axum::headers::Error::invalid)?
1175            .to_str()
1176            .map_err(|_| axum::headers::Error::invalid())?
1177            .parse()
1178            .map_err(|_| axum::headers::Error::invalid())?;
1179        Ok(Self(version))
1180    }
1181
1182    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1183        values.extend([self.0.to_string().parse().unwrap()]);
1184    }
1185}
1186
1187pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1188    Router::new()
1189        .route("/rpc", get(handle_websocket_request))
1190        .layer(
1191            ServiceBuilder::new()
1192                .layer(Extension(server.app_state.clone()))
1193                .layer(middleware::from_fn(auth::validate_header)),
1194        )
1195        .route("/metrics", get(handle_metrics))
1196        .layer(Extension(server))
1197}
1198
1199pub async fn handle_websocket_request(
1200    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1201    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1202    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1203    Extension(server): Extension<Arc<Server>>,
1204    Extension(principal): Extension<Principal>,
1205    user_agent: Option<TypedHeader<UserAgent>>,
1206    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1207    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1208    ws: WebSocketUpgrade,
1209) -> axum::response::Response {
1210    if protocol_version != rpc::PROTOCOL_VERSION {
1211        return (
1212            StatusCode::UPGRADE_REQUIRED,
1213            "client must be upgraded".to_string(),
1214        )
1215            .into_response();
1216    }
1217
1218    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1219        return (
1220            StatusCode::UPGRADE_REQUIRED,
1221            "no version header found".to_string(),
1222        )
1223            .into_response();
1224    };
1225
1226    if !version.can_collaborate() {
1227        return (
1228            StatusCode::UPGRADE_REQUIRED,
1229            "client must be upgraded".to_string(),
1230        )
1231            .into_response();
1232    }
1233
1234    let socket_address = socket_address.to_string();
1235
1236    // Acquire connection guard before WebSocket upgrade
1237    let connection_guard = match ConnectionGuard::try_acquire() {
1238        Ok(guard) => guard,
1239        Err(()) => {
1240            return (
1241                StatusCode::SERVICE_UNAVAILABLE,
1242                "Too many concurrent connections",
1243            )
1244                .into_response();
1245        }
1246    };
1247
1248    ws.on_upgrade(move |socket| {
1249        let socket = socket
1250            .map_ok(to_tungstenite_message)
1251            .err_into()
1252            .with(|message| async move { to_axum_message(message) });
1253        let connection = Connection::new(Box::pin(socket));
1254        async move {
1255            server
1256                .handle_connection(
1257                    connection,
1258                    socket_address,
1259                    principal,
1260                    version,
1261                    user_agent.map(|header| header.to_string()),
1262                    country_code_header.map(|header| header.to_string()),
1263                    system_id_header.map(|header| header.to_string()),
1264                    None,
1265                    Executor::Production,
1266                    Some(connection_guard),
1267                )
1268                .await;
1269        }
1270    })
1271}
1272
1273pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1274    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1275    let connections_metric = CONNECTIONS_METRIC
1276        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1277
1278    let connections = server
1279        .connection_pool
1280        .lock()
1281        .connections()
1282        .filter(|connection| !connection.admin)
1283        .count();
1284    connections_metric.set(connections as _);
1285
1286    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1287    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1288        register_int_gauge!(
1289            "shared_projects",
1290            "number of open projects with one or more guests"
1291        )
1292        .unwrap()
1293    });
1294
1295    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1296    shared_projects_metric.set(shared_projects as _);
1297
1298    let encoder = prometheus::TextEncoder::new();
1299    let metric_families = prometheus::gather();
1300    let encoded_metrics = encoder
1301        .encode_to_string(&metric_families)
1302        .map_err(|err| anyhow!("{err}"))?;
1303    Ok(encoded_metrics)
1304}
1305
/// Signs a connection out after its message loop ends.
///
/// Work done immediately:
/// - disconnects the connection from the peer;
/// - removes it from the connection pool (a failure here is returned);
/// - records the loss in the database (failures are only logged via `trace_err`).
///
/// It then waits up to `RECONNECT_TIMEOUT`. If the timeout elapses first, the session's
/// resources are released:
/// - it leaves its room and channel buffers;
/// - if the user has no remaining online connection, it declines any pending call and
///   broadcasts the resulting room update;
/// - it updates the user's contacts.
///
/// If the server's teardown signal changes first, that cleanup is skipped.
///
/// # Errors
/// Returns an error if pool removal or the final `update_user_contacts` call fails; `err` in
/// `#[instrument]` records such errors on the span.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased: if the timeout and teardown are both ready, the timeout branch is polled first.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call when no other connection of this user is online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1352
1353/// Acknowledges a ping from a client, used to keep the connection alive.
1354async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1355    response.send(proto::Ack {})?;
1356    Ok(())
1357}
1358
1359/// Creates a new room for calling (outside of channels)
1360async fn create_room(
1361    _request: proto::CreateRoom,
1362    response: Response<proto::CreateRoom>,
1363    session: Session,
1364) -> Result<()> {
1365    let livekit_room = nanoid::nanoid!(30);
1366
1367    let live_kit_connection_info = util::maybe!(async {
1368        let live_kit = session.app_state.livekit_client.as_ref();
1369        let live_kit = live_kit?;
1370        let user_id = session.user_id().to_string();
1371
1372        let token = live_kit
1373            .room_token(&livekit_room, &user_id.to_string())
1374            .trace_err()?;
1375
1376        Some(proto::LiveKitConnectionInfo {
1377            server_url: live_kit.url().into(),
1378            token,
1379            can_publish: true,
1380        })
1381    })
1382    .await;
1383
1384    let room = session
1385        .db()
1386        .await
1387        .create_room(session.user_id(), session.connection_id, &livekit_room)
1388        .await?;
1389
1390    response.send(proto::CreateRoomResponse {
1391        room: Some(room.clone()),
1392        live_kit_connection_info,
1393    })?;
1394
1395    update_user_contacts(session.user_id(), &session).await?;
1396    Ok(())
1397}
1398
1399/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1400async fn join_room(
1401    request: proto::JoinRoom,
1402    response: Response<proto::JoinRoom>,
1403    session: Session,
1404) -> Result<()> {
1405    let room_id = RoomId::from_proto(request.id);
1406
1407    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1408
1409    if let Some(channel_id) = channel_id {
1410        return join_channel_internal(channel_id, Box::new(response), session).await;
1411    }
1412
1413    let joined_room = {
1414        let room = session
1415            .db()
1416            .await
1417            .join_room(room_id, session.user_id(), session.connection_id)
1418            .await?;
1419        room_updated(&room.room, &session.peer);
1420        room.into_inner()
1421    };
1422
1423    for connection_id in session
1424        .connection_pool()
1425        .await
1426        .user_connection_ids(session.user_id())
1427    {
1428        session
1429            .peer
1430            .send(
1431                connection_id,
1432                proto::CallCanceled {
1433                    room_id: room_id.to_proto(),
1434                },
1435            )
1436            .trace_err();
1437    }
1438
1439    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1440    {
1441        live_kit
1442            .room_token(
1443                &joined_room.room.livekit_room,
1444                &session.user_id().to_string(),
1445            )
1446            .trace_err()
1447            .map(|token| proto::LiveKitConnectionInfo {
1448                server_url: live_kit.url().into(),
1449                token,
1450                can_publish: true,
1451            })
1452    } else {
1453        None
1454    };
1455
1456    response.send(proto::JoinRoomResponse {
1457        room: Some(joined_room.room),
1458        channel_id: None,
1459        live_kit_connection_info,
1460    })?;
1461
1462    update_user_contacts(session.user_id(), &session).await?;
1463    Ok(())
1464}
1465
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Asks the database to restore this connection's room membership, replies with the room
/// state plus the reshared and rejoined projects, and then notifies the collaborators of
/// those projects that this connection replaced the project's old connection id. If the
/// database reports a channel for the room, channel members are updated afterwards; finally
/// the user's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    // The inner scope confines `rejoined_room`; `room` and `channel` are moved out of it via
    // `into_inner` so they can be used after the scope ends.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Reply to the requester before any messages are sent to other participants.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // For each reshared project: tell every collaborator that this connection replaces
        // `old_connection_id` (best-effort, errors are only logged), then forward the
        // project's current worktree list to them on behalf of this connection.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // NOTE(review): `Some(session.connection_id)` is presumably the sender that
            // `broadcast` excludes from delivery — confirm against `broadcast`.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Announce the peer-id change for rejoined projects and stream their state back to
        // this connection (this moves worktree/repository data out of the list).
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // Channel-backed rooms also need the channel's members updated.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1557
1558fn notify_rejoined_projects(
1559    rejoined_projects: &mut Vec<RejoinedProject>,
1560    session: &Session,
1561) -> Result<()> {
1562    for project in rejoined_projects.iter() {
1563        for collaborator in &project.collaborators {
1564            session
1565                .peer
1566                .send(
1567                    collaborator.connection_id,
1568                    proto::UpdateProjectCollaborator {
1569                        project_id: project.id.to_proto(),
1570                        old_peer_id: Some(project.old_connection_id.into()),
1571                        new_peer_id: Some(session.connection_id.into()),
1572                    },
1573                )
1574                .trace_err();
1575        }
1576    }
1577
1578    for project in rejoined_projects {
1579        for worktree in mem::take(&mut project.worktrees) {
1580            // Stream this worktree's entries.
1581            let message = proto::UpdateWorktree {
1582                project_id: project.id.to_proto(),
1583                worktree_id: worktree.id,
1584                abs_path: worktree.abs_path.clone(),
1585                root_name: worktree.root_name,
1586                updated_entries: worktree.updated_entries,
1587                removed_entries: worktree.removed_entries,
1588                scan_id: worktree.scan_id,
1589                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1590                updated_repositories: worktree.updated_repositories,
1591                removed_repositories: worktree.removed_repositories,
1592            };
1593            for update in proto::split_worktree_update(message) {
1594                session.peer.send(session.connection_id, update)?;
1595            }
1596
1597            // Stream this worktree's diagnostics.
1598            for summary in worktree.diagnostic_summaries {
1599                session.peer.send(
1600                    session.connection_id,
1601                    proto::UpdateDiagnosticSummary {
1602                        project_id: project.id.to_proto(),
1603                        worktree_id: worktree.id,
1604                        summary: Some(summary),
1605                    },
1606                )?;
1607            }
1608
1609            for settings_file in worktree.settings_files {
1610                session.peer.send(
1611                    session.connection_id,
1612                    proto::UpdateWorktreeSettings {
1613                        project_id: project.id.to_proto(),
1614                        worktree_id: worktree.id,
1615                        path: settings_file.path,
1616                        content: Some(settings_file.content),
1617                        kind: Some(settings_file.kind.to_proto().into()),
1618                    },
1619                )?;
1620            }
1621        }
1622
1623        for repository in mem::take(&mut project.updated_repositories) {
1624            for update in split_repository_update(repository) {
1625                session.peer.send(session.connection_id, update)?;
1626            }
1627        }
1628
1629        for id in mem::take(&mut project.removed_repositories) {
1630            session.peer.send(
1631                session.connection_id,
1632                proto::RemoveRepository {
1633                    project_id: project.id.to_proto(),
1634                    id,
1635                },
1636            )?;
1637        }
1638    }
1639
1640    Ok(())
1641}
1642
1643/// leave room disconnects from the room.
1644async fn leave_room(
1645    _: proto::LeaveRoom,
1646    response: Response<proto::LeaveRoom>,
1647    session: Session,
1648) -> Result<()> {
1649    leave_room_for_session(&session, session.connection_id).await?;
1650    response.send(proto::Ack {})?;
1651    Ok(())
1652}
1653
1654/// Updates the permissions of someone else in the room.
1655async fn set_room_participant_role(
1656    request: proto::SetRoomParticipantRole,
1657    response: Response<proto::SetRoomParticipantRole>,
1658    session: Session,
1659) -> Result<()> {
1660    let user_id = UserId::from_proto(request.user_id);
1661    let role = ChannelRole::from(request.role());
1662
1663    let (livekit_room, can_publish) = {
1664        let room = session
1665            .db()
1666            .await
1667            .set_room_participant_role(
1668                session.user_id(),
1669                RoomId::from_proto(request.room_id),
1670                user_id,
1671                role,
1672            )
1673            .await?;
1674
1675        let livekit_room = room.livekit_room.clone();
1676        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1677        room_updated(&room, &session.peer);
1678        (livekit_room, can_publish)
1679    };
1680
1681    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1682        live_kit
1683            .update_participant(
1684                livekit_room.clone(),
1685                request.user_id.to_string(),
1686                livekit_api::proto::ParticipantPermission {
1687                    can_subscribe: true,
1688                    can_publish,
1689                    can_publish_data: can_publish,
1690                    hidden: false,
1691                    recorder: false,
1692                },
1693            )
1694            .await
1695            .trace_err();
1696    }
1697
1698    response.send(proto::Ack {})?;
1699    Ok(())
1700}
1701
/// Call someone else into the current room
///
/// The called user must be a contact of the caller (checked via `has_contact`). The call is
/// recorded through the database, the room is broadcast, and the called user's contacts are
/// refreshed. The incoming-call request is then sent to every connection of the called user
/// concurrently; the first one to answer successfully causes an `Ack` to be returned. If none
/// answers successfully (including when the user has no connections), the database is told
/// the call failed, the room is broadcast again, the contacts are refreshed again, and the
/// handler returns an error.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Only contacts may be called.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        // `Err(..)?` converts the anyhow error into this function's error type.
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Scoped so the value returned by `call` is dropped before `update_user_contacts` runs;
    // the incoming-call payload is moved out of it with `mem::take`.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the called user at once; responses are yielded in
    // completion order.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            // The first successful answer is enough to report success to the caller.
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            // Log the failure and keep waiting for the remaining connections.
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection answered successfully: record the failed call and broadcast the room.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1770
1771/// Cancel an outgoing call.
1772async fn cancel_call(
1773    request: proto::CancelCall,
1774    response: Response<proto::CancelCall>,
1775    session: Session,
1776) -> Result<()> {
1777    let called_user_id = UserId::from_proto(request.called_user_id);
1778    let room_id = RoomId::from_proto(request.room_id);
1779    {
1780        let room = session
1781            .db()
1782            .await
1783            .cancel_call(room_id, session.connection_id, called_user_id)
1784            .await?;
1785        room_updated(&room, &session.peer);
1786    }
1787
1788    for connection_id in session
1789        .connection_pool()
1790        .await
1791        .user_connection_ids(called_user_id)
1792    {
1793        session
1794            .peer
1795            .send(
1796                connection_id,
1797                proto::CallCanceled {
1798                    room_id: room_id.to_proto(),
1799                },
1800            )
1801            .trace_err();
1802    }
1803    response.send(proto::Ack {})?;
1804
1805    update_user_contacts(called_user_id, &session).await?;
1806    Ok(())
1807}
1808
1809/// Decline an incoming call.
1810async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1811    let room_id = RoomId::from_proto(message.room_id);
1812    {
1813        let room = session
1814            .db()
1815            .await
1816            .decline_call(Some(room_id), session.user_id())
1817            .await?
1818            .context("declining call")?;
1819        room_updated(&room, &session.peer);
1820    }
1821
1822    for connection_id in session
1823        .connection_pool()
1824        .await
1825        .user_connection_ids(session.user_id())
1826    {
1827        session
1828            .peer
1829            .send(
1830                connection_id,
1831                proto::CallCanceled {
1832                    room_id: room_id.to_proto(),
1833                },
1834            )
1835            .trace_err();
1836    }
1837    update_user_contacts(session.user_id(), &session).await?;
1838    Ok(())
1839}
1840
1841/// Updates other participants in the room with your current location.
1842async fn update_participant_location(
1843    request: proto::UpdateParticipantLocation,
1844    response: Response<proto::UpdateParticipantLocation>,
1845    session: Session,
1846) -> Result<()> {
1847    let room_id = RoomId::from_proto(request.room_id);
1848    let location = request.location.context("invalid location")?;
1849
1850    let db = session.db().await;
1851    let room = db
1852        .update_room_participant_location(room_id, session.connection_id, location)
1853        .await?;
1854
1855    room_updated(&room, &session.peer);
1856    response.send(proto::Ack {})?;
1857    Ok(())
1858}
1859
1860/// Share a project into the room.
1861async fn share_project(
1862    request: proto::ShareProject,
1863    response: Response<proto::ShareProject>,
1864    session: Session,
1865) -> Result<()> {
1866    let (project_id, room) = &*session
1867        .db()
1868        .await
1869        .share_project(
1870            RoomId::from_proto(request.room_id),
1871            session.connection_id,
1872            &request.worktrees,
1873            request.is_ssh_project,
1874        )
1875        .await?;
1876    response.send(proto::ShareProjectResponse {
1877        project_id: project_id.to_proto(),
1878    })?;
1879    room_updated(room, &session.peer);
1880
1881    Ok(())
1882}
1883
1884/// Unshare a project from the room.
1885async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1886    let project_id = ProjectId::from_proto(message.project_id);
1887    unshare_project_internal(project_id, session.connection_id, &session).await
1888}
1889
1890async fn unshare_project_internal(
1891    project_id: ProjectId,
1892    connection_id: ConnectionId,
1893    session: &Session,
1894) -> Result<()> {
1895    let delete = {
1896        let room_guard = session
1897            .db()
1898            .await
1899            .unshare_project(project_id, connection_id)
1900            .await?;
1901
1902        let (delete, room, guest_connection_ids) = &*room_guard;
1903
1904        let message = proto::UnshareProject {
1905            project_id: project_id.to_proto(),
1906        };
1907
1908        broadcast(
1909            Some(connection_id),
1910            guest_connection_ids.iter().copied(),
1911            |conn_id| session.peer.send(conn_id, message.clone()),
1912        );
1913        if let Some(room) = room {
1914            room_updated(room, &session.peer);
1915        }
1916
1917        *delete
1918    };
1919
1920    if delete {
1921        let db = session.db().await;
1922        db.delete_project(project_id).await?;
1923    }
1924
1925    Ok(())
1926}
1927
1928/// Join someone elses shared project.
1929async fn join_project(
1930    request: proto::JoinProject,
1931    response: Response<proto::JoinProject>,
1932    session: Session,
1933) -> Result<()> {
1934    let project_id = ProjectId::from_proto(request.project_id);
1935
1936    tracing::info!(%project_id, "join project");
1937
1938    let db = session.db().await;
1939    let (project, replica_id) = &mut *db
1940        .join_project(
1941            project_id,
1942            session.connection_id,
1943            session.user_id(),
1944            request.committer_name.clone(),
1945            request.committer_email.clone(),
1946        )
1947        .await?;
1948    drop(db);
1949    tracing::info!(%project_id, "join remote project");
1950    let collaborators = project
1951        .collaborators
1952        .iter()
1953        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1954        .map(|collaborator| collaborator.to_proto())
1955        .collect::<Vec<_>>();
1956    let project_id = project.id;
1957    let guest_user_id = session.user_id();
1958
1959    let worktrees = project
1960        .worktrees
1961        .iter()
1962        .map(|(id, worktree)| proto::WorktreeMetadata {
1963            id: *id,
1964            root_name: worktree.root_name.clone(),
1965            visible: worktree.visible,
1966            abs_path: worktree.abs_path.clone(),
1967        })
1968        .collect::<Vec<_>>();
1969
1970    let add_project_collaborator = proto::AddProjectCollaborator {
1971        project_id: project_id.to_proto(),
1972        collaborator: Some(proto::Collaborator {
1973            peer_id: Some(session.connection_id.into()),
1974            replica_id: replica_id.0 as u32,
1975            user_id: guest_user_id.to_proto(),
1976            is_host: false,
1977            committer_name: request.committer_name.clone(),
1978            committer_email: request.committer_email.clone(),
1979        }),
1980    };
1981
1982    for collaborator in &collaborators {
1983        session
1984            .peer
1985            .send(
1986                collaborator.peer_id.unwrap().into(),
1987                add_project_collaborator.clone(),
1988            )
1989            .trace_err();
1990    }
1991
1992    // First, we send the metadata associated with each worktree.
1993    let (language_servers, language_server_capabilities) = project
1994        .language_servers
1995        .clone()
1996        .into_iter()
1997        .map(|server| (server.server, server.capabilities))
1998        .unzip();
1999    response.send(proto::JoinProjectResponse {
2000        project_id: project.id.0 as u64,
2001        worktrees: worktrees.clone(),
2002        replica_id: replica_id.0 as u32,
2003        collaborators: collaborators.clone(),
2004        language_servers,
2005        language_server_capabilities,
2006        role: project.role.into(),
2007    })?;
2008
2009    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2010        // Stream this worktree's entries.
2011        let message = proto::UpdateWorktree {
2012            project_id: project_id.to_proto(),
2013            worktree_id,
2014            abs_path: worktree.abs_path.clone(),
2015            root_name: worktree.root_name,
2016            updated_entries: worktree.entries,
2017            removed_entries: Default::default(),
2018            scan_id: worktree.scan_id,
2019            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2020            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2021            removed_repositories: Default::default(),
2022        };
2023        for update in proto::split_worktree_update(message) {
2024            session.peer.send(session.connection_id, update.clone())?;
2025        }
2026
2027        // Stream this worktree's diagnostics.
2028        for summary in worktree.diagnostic_summaries {
2029            session.peer.send(
2030                session.connection_id,
2031                proto::UpdateDiagnosticSummary {
2032                    project_id: project_id.to_proto(),
2033                    worktree_id: worktree.id,
2034                    summary: Some(summary),
2035                },
2036            )?;
2037        }
2038
2039        for settings_file in worktree.settings_files {
2040            session.peer.send(
2041                session.connection_id,
2042                proto::UpdateWorktreeSettings {
2043                    project_id: project_id.to_proto(),
2044                    worktree_id: worktree.id,
2045                    path: settings_file.path,
2046                    content: Some(settings_file.content),
2047                    kind: Some(settings_file.kind.to_proto() as i32),
2048                },
2049            )?;
2050        }
2051    }
2052
2053    for repository in mem::take(&mut project.repositories) {
2054        for update in split_repository_update(repository) {
2055            session.peer.send(session.connection_id, update)?;
2056        }
2057    }
2058
2059    for language_server in &project.language_servers {
2060        session.peer.send(
2061            session.connection_id,
2062            proto::UpdateLanguageServer {
2063                project_id: project_id.to_proto(),
2064                server_name: Some(language_server.server.name.clone()),
2065                language_server_id: language_server.server.id,
2066                variant: Some(
2067                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2068                        proto::LspDiskBasedDiagnosticsUpdated {},
2069                    ),
2070                ),
2071            },
2072        )?;
2073    }
2074
2075    Ok(())
2076}
2077
2078/// Leave someone elses shared project.
2079async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
2080    let sender_id = session.connection_id;
2081    let project_id = ProjectId::from_proto(request.project_id);
2082    let db = session.db().await;
2083
2084    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2085    tracing::info!(
2086        %project_id,
2087        "leave project"
2088    );
2089
2090    project_left(project, &session);
2091    if let Some(room) = room {
2092        room_updated(room, &session.peer);
2093    }
2094
2095    Ok(())
2096}
2097
2098/// Updates other participants with changes to the project
2099async fn update_project(
2100    request: proto::UpdateProject,
2101    response: Response<proto::UpdateProject>,
2102    session: Session,
2103) -> Result<()> {
2104    let project_id = ProjectId::from_proto(request.project_id);
2105    let (room, guest_connection_ids) = &*session
2106        .db()
2107        .await
2108        .update_project(project_id, session.connection_id, &request.worktrees)
2109        .await?;
2110    broadcast(
2111        Some(session.connection_id),
2112        guest_connection_ids.iter().copied(),
2113        |connection_id| {
2114            session
2115                .peer
2116                .forward_send(session.connection_id, connection_id, request.clone())
2117        },
2118    );
2119    if let Some(room) = room {
2120        room_updated(room, &session.peer);
2121    }
2122    response.send(proto::Ack {})?;
2123
2124    Ok(())
2125}
2126
2127/// Updates other participants with changes to the worktree
2128async fn update_worktree(
2129    request: proto::UpdateWorktree,
2130    response: Response<proto::UpdateWorktree>,
2131    session: Session,
2132) -> Result<()> {
2133    let guest_connection_ids = session
2134        .db()
2135        .await
2136        .update_worktree(&request, session.connection_id)
2137        .await?;
2138
2139    broadcast(
2140        Some(session.connection_id),
2141        guest_connection_ids.iter().copied(),
2142        |connection_id| {
2143            session
2144                .peer
2145                .forward_send(session.connection_id, connection_id, request.clone())
2146        },
2147    );
2148    response.send(proto::Ack {})?;
2149    Ok(())
2150}
2151
2152async fn update_repository(
2153    request: proto::UpdateRepository,
2154    response: Response<proto::UpdateRepository>,
2155    session: Session,
2156) -> Result<()> {
2157    let guest_connection_ids = session
2158        .db()
2159        .await
2160        .update_repository(&request, session.connection_id)
2161        .await?;
2162
2163    broadcast(
2164        Some(session.connection_id),
2165        guest_connection_ids.iter().copied(),
2166        |connection_id| {
2167            session
2168                .peer
2169                .forward_send(session.connection_id, connection_id, request.clone())
2170        },
2171    );
2172    response.send(proto::Ack {})?;
2173    Ok(())
2174}
2175
2176async fn remove_repository(
2177    request: proto::RemoveRepository,
2178    response: Response<proto::RemoveRepository>,
2179    session: Session,
2180) -> Result<()> {
2181    let guest_connection_ids = session
2182        .db()
2183        .await
2184        .remove_repository(&request, session.connection_id)
2185        .await?;
2186
2187    broadcast(
2188        Some(session.connection_id),
2189        guest_connection_ids.iter().copied(),
2190        |connection_id| {
2191            session
2192                .peer
2193                .forward_send(session.connection_id, connection_id, request.clone())
2194        },
2195    );
2196    response.send(proto::Ack {})?;
2197    Ok(())
2198}
2199
2200/// Updates other participants with changes to the diagnostics
2201async fn update_diagnostic_summary(
2202    message: proto::UpdateDiagnosticSummary,
2203    session: Session,
2204) -> Result<()> {
2205    let guest_connection_ids = session
2206        .db()
2207        .await
2208        .update_diagnostic_summary(&message, session.connection_id)
2209        .await?;
2210
2211    broadcast(
2212        Some(session.connection_id),
2213        guest_connection_ids.iter().copied(),
2214        |connection_id| {
2215            session
2216                .peer
2217                .forward_send(session.connection_id, connection_id, message.clone())
2218        },
2219    );
2220
2221    Ok(())
2222}
2223
2224/// Updates other participants with changes to the worktree settings
2225async fn update_worktree_settings(
2226    message: proto::UpdateWorktreeSettings,
2227    session: Session,
2228) -> Result<()> {
2229    let guest_connection_ids = session
2230        .db()
2231        .await
2232        .update_worktree_settings(&message, session.connection_id)
2233        .await?;
2234
2235    broadcast(
2236        Some(session.connection_id),
2237        guest_connection_ids.iter().copied(),
2238        |connection_id| {
2239            session
2240                .peer
2241                .forward_send(session.connection_id, connection_id, message.clone())
2242        },
2243    );
2244
2245    Ok(())
2246}
2247
2248/// Notify other participants that a language server has started.
2249async fn start_language_server(
2250    request: proto::StartLanguageServer,
2251    session: Session,
2252) -> Result<()> {
2253    let guest_connection_ids = session
2254        .db()
2255        .await
2256        .start_language_server(&request, session.connection_id)
2257        .await?;
2258
2259    broadcast(
2260        Some(session.connection_id),
2261        guest_connection_ids.iter().copied(),
2262        |connection_id| {
2263            session
2264                .peer
2265                .forward_send(session.connection_id, connection_id, request.clone())
2266        },
2267    );
2268    Ok(())
2269}
2270
2271/// Notify other participants that a language server has changed.
2272async fn update_language_server(
2273    request: proto::UpdateLanguageServer,
2274    session: Session,
2275) -> Result<()> {
2276    let project_id = ProjectId::from_proto(request.project_id);
2277    let db = session.db().await;
2278
2279    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2280    {
2281        if let Some(capabilities) = update.capabilities.clone() {
2282            db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2283                .await?;
2284        }
2285    }
2286
2287    let project_connection_ids = db
2288        .project_connection_ids(project_id, session.connection_id, true)
2289        .await?;
2290    broadcast(
2291        Some(session.connection_id),
2292        project_connection_ids.iter().copied(),
2293        |connection_id| {
2294            session
2295                .peer
2296                .forward_send(session.connection_id, connection_id, request.clone())
2297        },
2298    );
2299    Ok(())
2300}
2301
2302/// forward a project request to the host. These requests should be read only
2303/// as guests are allowed to send them.
2304async fn forward_read_only_project_request<T>(
2305    request: T,
2306    response: Response<T>,
2307    session: Session,
2308) -> Result<()>
2309where
2310    T: EntityMessage + RequestMessage,
2311{
2312    let project_id = ProjectId::from_proto(request.remote_entity_id());
2313    let host_connection_id = session
2314        .db()
2315        .await
2316        .host_for_read_only_project_request(project_id, session.connection_id)
2317        .await?;
2318    let payload = session
2319        .peer
2320        .forward_request(session.connection_id, host_connection_id, request)
2321        .await?;
2322    response.send(payload)?;
2323    Ok(())
2324}
2325
2326/// forward a project request to the host. These requests are disallowed
2327/// for guests.
2328async fn forward_mutating_project_request<T>(
2329    request: T,
2330    response: Response<T>,
2331    session: Session,
2332) -> Result<()>
2333where
2334    T: EntityMessage + RequestMessage,
2335{
2336    let project_id = ProjectId::from_proto(request.remote_entity_id());
2337
2338    let host_connection_id = session
2339        .db()
2340        .await
2341        .host_for_mutating_project_request(project_id, session.connection_id)
2342        .await?;
2343    let payload = session
2344        .peer
2345        .forward_request(session.connection_id, host_connection_id, request)
2346        .await?;
2347    response.send(payload)?;
2348    Ok(())
2349}
2350
2351async fn multi_lsp_query(
2352    request: MultiLspQuery,
2353    response: Response<MultiLspQuery>,
2354    session: Session,
2355) -> Result<()> {
2356    tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2357    tracing::info!("multi_lsp_query message received");
2358    forward_mutating_project_request(request, response, session).await
2359}
2360
2361/// Notify other participants that a new buffer has been created
2362async fn create_buffer_for_peer(
2363    request: proto::CreateBufferForPeer,
2364    session: Session,
2365) -> Result<()> {
2366    session
2367        .db()
2368        .await
2369        .check_user_is_project_host(
2370            ProjectId::from_proto(request.project_id),
2371            session.connection_id,
2372        )
2373        .await?;
2374    let peer_id = request.peer_id.context("invalid peer id")?;
2375    session
2376        .peer
2377        .forward_send(session.connection_id, peer_id.into(), request)?;
2378    Ok(())
2379}
2380
2381/// Notify other participants that a buffer has been updated. This is
2382/// allowed for guests as long as the update is limited to selections.
2383async fn update_buffer(
2384    request: proto::UpdateBuffer,
2385    response: Response<proto::UpdateBuffer>,
2386    session: Session,
2387) -> Result<()> {
2388    let project_id = ProjectId::from_proto(request.project_id);
2389    let mut capability = Capability::ReadOnly;
2390
2391    for op in request.operations.iter() {
2392        match op.variant {
2393            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2394            Some(_) => capability = Capability::ReadWrite,
2395        }
2396    }
2397
2398    let host = {
2399        let guard = session
2400            .db()
2401            .await
2402            .connections_for_buffer_update(project_id, session.connection_id, capability)
2403            .await?;
2404
2405        let (host, guests) = &*guard;
2406
2407        broadcast(
2408            Some(session.connection_id),
2409            guests.clone(),
2410            |connection_id| {
2411                session
2412                    .peer
2413                    .forward_send(session.connection_id, connection_id, request.clone())
2414            },
2415        );
2416
2417        *host
2418    };
2419
2420    if host != session.connection_id {
2421        session
2422            .peer
2423            .forward_request(session.connection_id, host, request.clone())
2424            .await?;
2425    }
2426
2427    response.send(proto::Ack {})?;
2428    Ok(())
2429}
2430
2431async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2432    let project_id = ProjectId::from_proto(message.project_id);
2433
2434    let operation = message.operation.as_ref().context("invalid operation")?;
2435    let capability = match operation.variant.as_ref() {
2436        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2437            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2438                match buffer_op.variant {
2439                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2440                        Capability::ReadOnly
2441                    }
2442                    _ => Capability::ReadWrite,
2443                }
2444            } else {
2445                Capability::ReadWrite
2446            }
2447        }
2448        Some(_) => Capability::ReadWrite,
2449        None => Capability::ReadOnly,
2450    };
2451
2452    let guard = session
2453        .db()
2454        .await
2455        .connections_for_buffer_update(project_id, session.connection_id, capability)
2456        .await?;
2457
2458    let (host, guests) = &*guard;
2459
2460    broadcast(
2461        Some(session.connection_id),
2462        guests.iter().chain([host]).copied(),
2463        |connection_id| {
2464            session
2465                .peer
2466                .forward_send(session.connection_id, connection_id, message.clone())
2467        },
2468    );
2469
2470    Ok(())
2471}
2472
2473/// Notify other participants that a project has been updated.
2474async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2475    request: T,
2476    session: Session,
2477) -> Result<()> {
2478    let project_id = ProjectId::from_proto(request.remote_entity_id());
2479    let project_connection_ids = session
2480        .db()
2481        .await
2482        .project_connection_ids(project_id, session.connection_id, false)
2483        .await?;
2484
2485    broadcast(
2486        Some(session.connection_id),
2487        project_connection_ids.iter().copied(),
2488        |connection_id| {
2489            session
2490                .peer
2491                .forward_send(session.connection_id, connection_id, request.clone())
2492        },
2493    );
2494    Ok(())
2495}
2496
2497/// Start following another user in a call.
2498async fn follow(
2499    request: proto::Follow,
2500    response: Response<proto::Follow>,
2501    session: Session,
2502) -> Result<()> {
2503    let room_id = RoomId::from_proto(request.room_id);
2504    let project_id = request.project_id.map(ProjectId::from_proto);
2505    let leader_id = request.leader_id.context("invalid leader id")?.into();
2506    let follower_id = session.connection_id;
2507
2508    session
2509        .db()
2510        .await
2511        .check_room_participants(room_id, leader_id, session.connection_id)
2512        .await?;
2513
2514    let response_payload = session
2515        .peer
2516        .forward_request(session.connection_id, leader_id, request)
2517        .await?;
2518    response.send(response_payload)?;
2519
2520    if let Some(project_id) = project_id {
2521        let room = session
2522            .db()
2523            .await
2524            .follow(room_id, project_id, leader_id, follower_id)
2525            .await?;
2526        room_updated(&room, &session.peer);
2527    }
2528
2529    Ok(())
2530}
2531
2532/// Stop following another user in a call.
2533async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2534    let room_id = RoomId::from_proto(request.room_id);
2535    let project_id = request.project_id.map(ProjectId::from_proto);
2536    let leader_id = request.leader_id.context("invalid leader id")?.into();
2537    let follower_id = session.connection_id;
2538
2539    session
2540        .db()
2541        .await
2542        .check_room_participants(room_id, leader_id, session.connection_id)
2543        .await?;
2544
2545    session
2546        .peer
2547        .forward_send(session.connection_id, leader_id, request)?;
2548
2549    if let Some(project_id) = project_id {
2550        let room = session
2551            .db()
2552            .await
2553            .unfollow(room_id, project_id, leader_id, follower_id)
2554            .await?;
2555        room_updated(&room, &session.peer);
2556    }
2557
2558    Ok(())
2559}
2560
2561/// Notify everyone following you of your current location.
2562async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2563    let room_id = RoomId::from_proto(request.room_id);
2564    let database = session.db.lock().await;
2565
2566    let connection_ids = if let Some(project_id) = request.project_id {
2567        let project_id = ProjectId::from_proto(project_id);
2568        database
2569            .project_connection_ids(project_id, session.connection_id, true)
2570            .await?
2571    } else {
2572        database
2573            .room_connection_ids(room_id, session.connection_id)
2574            .await?
2575    };
2576
2577    // For now, don't send view update messages back to that view's current leader.
2578    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2579        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2580        _ => None,
2581    });
2582
2583    for connection_id in connection_ids.iter().cloned() {
2584        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2585            session
2586                .peer
2587                .forward_send(session.connection_id, connection_id, request.clone())?;
2588        }
2589    }
2590    Ok(())
2591}
2592
2593/// Get public data about users.
2594async fn get_users(
2595    request: proto::GetUsers,
2596    response: Response<proto::GetUsers>,
2597    session: Session,
2598) -> Result<()> {
2599    let user_ids = request
2600        .user_ids
2601        .into_iter()
2602        .map(UserId::from_proto)
2603        .collect();
2604    let users = session
2605        .db()
2606        .await
2607        .get_users_by_ids(user_ids)
2608        .await?
2609        .into_iter()
2610        .map(|user| proto::User {
2611            id: user.id.to_proto(),
2612            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2613            github_login: user.github_login,
2614            name: user.name,
2615        })
2616        .collect();
2617    response.send(proto::UsersResponse { users })?;
2618    Ok(())
2619}
2620
2621/// Search for users (to invite) buy Github login
2622async fn fuzzy_search_users(
2623    request: proto::FuzzySearchUsers,
2624    response: Response<proto::FuzzySearchUsers>,
2625    session: Session,
2626) -> Result<()> {
2627    let query = request.query;
2628    let users = match query.len() {
2629        0 => vec![],
2630        1 | 2 => session
2631            .db()
2632            .await
2633            .get_user_by_github_login(&query)
2634            .await?
2635            .into_iter()
2636            .collect(),
2637        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2638    };
2639    let users = users
2640        .into_iter()
2641        .filter(|user| user.id != session.user_id())
2642        .map(|user| proto::User {
2643            id: user.id.to_proto(),
2644            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2645            github_login: user.github_login,
2646            name: user.name,
2647        })
2648        .collect();
2649    response.send(proto::UsersResponse { users })?;
2650    Ok(())
2651}
2652
2653/// Send a contact request to another user.
2654async fn request_contact(
2655    request: proto::RequestContact,
2656    response: Response<proto::RequestContact>,
2657    session: Session,
2658) -> Result<()> {
2659    let requester_id = session.user_id();
2660    let responder_id = UserId::from_proto(request.responder_id);
2661    if requester_id == responder_id {
2662        return Err(anyhow!("cannot add yourself as a contact"))?;
2663    }
2664
2665    let notifications = session
2666        .db()
2667        .await
2668        .send_contact_request(requester_id, responder_id)
2669        .await?;
2670
2671    // Update outgoing contact requests of requester
2672    let mut update = proto::UpdateContacts::default();
2673    update.outgoing_requests.push(responder_id.to_proto());
2674    for connection_id in session
2675        .connection_pool()
2676        .await
2677        .user_connection_ids(requester_id)
2678    {
2679        session.peer.send(connection_id, update.clone())?;
2680    }
2681
2682    // Update incoming contact requests of responder
2683    let mut update = proto::UpdateContacts::default();
2684    update
2685        .incoming_requests
2686        .push(proto::IncomingContactRequest {
2687            requester_id: requester_id.to_proto(),
2688        });
2689    let connection_pool = session.connection_pool().await;
2690    for connection_id in connection_pool.user_connection_ids(responder_id) {
2691        session.peer.send(connection_id, update.clone())?;
2692    }
2693
2694    send_notifications(&connection_pool, &session.peer, notifications);
2695
2696    response.send(proto::Ack {})?;
2697    Ok(())
2698}
2699
/// Accept or decline a contact request, or dismiss its notification.
///
/// `request.response` is compared against `proto::ContactRequestResponse`:
/// - `Dismiss` only calls `dismiss_contact_notification`; no contact updates
///   are pushed to either user.
/// - Any other value records the response, treating it as an accept only when
///   it equals `Accept`. It then pushes `UpdateContacts` to both users'
///   connections and delivers the resulting notifications.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: Session,
) -> Result<()> {
    // The current user is responding to a request sent by `requester_id`.
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        // Anything other than `Accept` (with `Dismiss` handled above) is a decline.
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy flags feed `contact_for_user` below. NOTE(review): they are
        // fetched even when declining, where they go unused.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        // Whether accepted or declined, the request is no longer pending for the responder.
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        // ...and it is no longer an outgoing request for the requester.
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2757
2758/// Remove a contact.
2759async fn remove_contact(
2760    request: proto::RemoveContact,
2761    response: Response<proto::RemoveContact>,
2762    session: Session,
2763) -> Result<()> {
2764    let requester_id = session.user_id();
2765    let responder_id = UserId::from_proto(request.user_id);
2766    let db = session.db().await;
2767    let (contact_accepted, deleted_notification_id) =
2768        db.remove_contact(requester_id, responder_id).await?;
2769
2770    let pool = session.connection_pool().await;
2771    // Update outgoing contact requests of requester
2772    let mut update = proto::UpdateContacts::default();
2773    if contact_accepted {
2774        update.remove_contacts.push(responder_id.to_proto());
2775    } else {
2776        update
2777            .remove_outgoing_requests
2778            .push(responder_id.to_proto());
2779    }
2780    for connection_id in pool.user_connection_ids(requester_id) {
2781        session.peer.send(connection_id, update.clone())?;
2782    }
2783
2784    // Update incoming contact requests of responder
2785    let mut update = proto::UpdateContacts::default();
2786    if contact_accepted {
2787        update.remove_contacts.push(requester_id.to_proto());
2788    } else {
2789        update
2790            .remove_incoming_requests
2791            .push(requester_id.to_proto());
2792    }
2793    for connection_id in pool.user_connection_ids(responder_id) {
2794        session.peer.send(connection_id, update.clone())?;
2795        if let Some(notification_id) = deleted_notification_id {
2796            session.peer.send(
2797                connection_id,
2798                proto::DeleteNotification {
2799                    notification_id: notification_id.to_proto(),
2800                },
2801            )?;
2802        }
2803    }
2804
2805    response.send(proto::Ack {})?;
2806    Ok(())
2807}
2808
2809fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2810    version.0.minor() < 139
2811}
2812
2813async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2814    if is_staff {
2815        return Ok(proto::Plan::ZedPro);
2816    }
2817
2818    let subscription = db.get_active_billing_subscription(user_id).await?;
2819    let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2820
2821    let plan = if let Some(subscription_kind) = subscription_kind {
2822        match subscription_kind {
2823            SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2824            SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2825            SubscriptionKind::ZedFree => proto::Plan::Free,
2826        }
2827    } else {
2828        proto::Plan::Free
2829    };
2830
2831    Ok(plan)
2832}
2833
/// Builds the `UpdateUserPlan` message describing `user`'s plan, trial start,
/// usage-based-billing setting, current subscription period, account-age
/// eligibility, overdue-invoice status, and usage.
///
/// When `llm_db` is `None`, the subscription period and recorded usage are
/// left unset. The message then reports default (zero) usage with the limits
/// that apply to the plan.
///
/// # Errors
///
/// Propagates any database error.
async fn make_update_user_plan_message(
    user: &User,
    is_staff: bool,
    db: &Arc<Database>,
    llm_db: Option<Arc<LlmDatabase>>,
) -> Result<proto::UpdateUserPlan> {
    let feature_flags = db.get_user_flags(user.id).await?;
    let plan = current_plan(db, user.id, is_staff).await?;
    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
        // NOTE(review): `current_plan` above already queried the active
        // subscription; this is a second lookup of the same row.
        let subscription = db.get_active_billing_subscription(user.id).await?;

        let subscription_period =
            crate::db::billing_subscription::Model::current_period(subscription, is_staff);

        // Usage is only looked up when there is a current period to bound it.
        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
            llm_db
                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
                .await?
        } else {
            None
        };

        (subscription_period, usage)
    } else {
        (None, None)
    };

    // Accounts below the minimum age are flagged unless they are on Zed Pro
    // or carry the bypass feature flag.
    let bypass_account_age_check = feature_flags
        .iter()
        .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
    let account_too_young = !matches!(plan, proto::Plan::ZedPro)
        && !bypass_account_age_check
        && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;

    Ok(proto::UpdateUserPlan {
        plan: plan.into(),
        // Sent as Unix seconds; the stored value is treated as UTC.
        trial_started_at: billing_customer
            .as_ref()
            .and_then(|billing_customer| billing_customer.trial_started_at)
            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
        // Staff always report usage-based billing as enabled.
        is_usage_based_billing_enabled: if is_staff {
            Some(true)
        } else {
            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
        },
        subscription_period: subscription_period.map(|(started_at, ended_at)| {
            proto::SubscriptionPeriod {
                started_at: started_at.timestamp() as u64,
                ended_at: ended_at.timestamp() as u64,
            }
        }),
        account_too_young: Some(account_too_young),
        has_overdue_invoices: billing_customer
            .map(|billing_customer| billing_customer.has_overdue_invoices),
        // Always present: recorded usage when available, otherwise zero usage
        // with the plan's limits.
        usage: Some(
            usage
                .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
                .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
        ),
    })
}
2898
2899fn model_requests_limit(
2900    plan: cloud_llm_client::Plan,
2901    feature_flags: &Vec<String>,
2902) -> cloud_llm_client::UsageLimit {
2903    match plan.model_requests_limit() {
2904        cloud_llm_client::UsageLimit::Limited(limit) => {
2905            let limit = if plan == cloud_llm_client::Plan::ZedProTrial
2906                && feature_flags
2907                    .iter()
2908                    .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2909            {
2910                1_000
2911            } else {
2912                limit
2913            };
2914
2915            cloud_llm_client::UsageLimit::Limited(limit)
2916        }
2917        cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited,
2918    }
2919}
2920
2921fn subscription_usage_to_proto(
2922    plan: proto::Plan,
2923    usage: crate::llm::db::subscription_usage::Model,
2924    feature_flags: &Vec<String>,
2925) -> proto::SubscriptionUsage {
2926    let plan = match plan {
2927        proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2928        proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
2929        proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
2930    };
2931
2932    proto::SubscriptionUsage {
2933        model_requests_usage_amount: usage.model_requests as u32,
2934        model_requests_usage_limit: Some(proto::UsageLimit {
2935            variant: Some(match model_requests_limit(plan, feature_flags) {
2936                cloud_llm_client::UsageLimit::Limited(limit) => {
2937                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2938                        limit: limit as u32,
2939                    })
2940                }
2941                cloud_llm_client::UsageLimit::Unlimited => {
2942                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2943                }
2944            }),
2945        }),
2946        edit_predictions_usage_amount: usage.edit_predictions as u32,
2947        edit_predictions_usage_limit: Some(proto::UsageLimit {
2948            variant: Some(match plan.edit_predictions_limit() {
2949                cloud_llm_client::UsageLimit::Limited(limit) => {
2950                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2951                        limit: limit as u32,
2952                    })
2953                }
2954                cloud_llm_client::UsageLimit::Unlimited => {
2955                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2956                }
2957            }),
2958        }),
2959    }
2960}
2961
2962fn make_default_subscription_usage(
2963    plan: proto::Plan,
2964    feature_flags: &Vec<String>,
2965) -> proto::SubscriptionUsage {
2966    let plan = match plan {
2967        proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2968        proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
2969        proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
2970    };
2971
2972    proto::SubscriptionUsage {
2973        model_requests_usage_amount: 0,
2974        model_requests_usage_limit: Some(proto::UsageLimit {
2975            variant: Some(match model_requests_limit(plan, feature_flags) {
2976                cloud_llm_client::UsageLimit::Limited(limit) => {
2977                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2978                        limit: limit as u32,
2979                    })
2980                }
2981                cloud_llm_client::UsageLimit::Unlimited => {
2982                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2983                }
2984            }),
2985        }),
2986        edit_predictions_usage_amount: 0,
2987        edit_predictions_usage_limit: Some(proto::UsageLimit {
2988            variant: Some(match plan.edit_predictions_limit() {
2989                cloud_llm_client::UsageLimit::Limited(limit) => {
2990                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2991                        limit: limit as u32,
2992                    })
2993                }
2994                cloud_llm_client::UsageLimit::Unlimited => {
2995                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2996                }
2997            }),
2998        }),
2999    }
3000}
3001
3002async fn update_user_plan(session: &Session) -> Result<()> {
3003    let db = session.db().await;
3004
3005    let update_user_plan = make_update_user_plan_message(
3006        session.principal.user(),
3007        session.is_staff(),
3008        &db.0,
3009        session.app_state.llm_db.clone(),
3010    )
3011    .await?;
3012
3013    session
3014        .peer
3015        .send(session.connection_id, update_user_plan)
3016        .trace_err();
3017
3018    Ok(())
3019}
3020
3021async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3022    subscribe_user_to_channels(session.user_id(), &session).await?;
3023    Ok(())
3024}
3025
3026async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3027    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3028    let mut pool = session.connection_pool().await;
3029    for membership in &channels_for_user.channel_memberships {
3030        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3031    }
3032    session.peer.send(
3033        session.connection_id,
3034        build_update_user_channels(&channels_for_user),
3035    )?;
3036    session.peer.send(
3037        session.connection_id,
3038        build_channels_update(channels_for_user),
3039    )?;
3040    Ok(())
3041}
3042
3043/// Creates a new channel.
3044async fn create_channel(
3045    request: proto::CreateChannel,
3046    response: Response<proto::CreateChannel>,
3047    session: Session,
3048) -> Result<()> {
3049    let db = session.db().await;
3050
3051    let parent_id = request.parent_id.map(ChannelId::from_proto);
3052    let (channel, membership) = db
3053        .create_channel(&request.name, parent_id, session.user_id())
3054        .await?;
3055
3056    let root_id = channel.root_id();
3057    let channel = Channel::from_model(channel);
3058
3059    response.send(proto::CreateChannelResponse {
3060        channel: Some(channel.to_proto()),
3061        parent_id: request.parent_id,
3062    })?;
3063
3064    let mut connection_pool = session.connection_pool().await;
3065    if let Some(membership) = membership {
3066        connection_pool.subscribe_to_channel(
3067            membership.user_id,
3068            membership.channel_id,
3069            membership.role,
3070        );
3071        let update = proto::UpdateUserChannels {
3072            channel_memberships: vec![proto::ChannelMembership {
3073                channel_id: membership.channel_id.to_proto(),
3074                role: membership.role.into(),
3075            }],
3076            ..Default::default()
3077        };
3078        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3079            session.peer.send(connection_id, update.clone())?;
3080        }
3081    }
3082
3083    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3084        if !role.can_see_channel(channel.visibility) {
3085            continue;
3086        }
3087
3088        let update = proto::UpdateChannels {
3089            channels: vec![channel.to_proto()],
3090            ..Default::default()
3091        };
3092        session.peer.send(connection_id, update.clone())?;
3093    }
3094
3095    Ok(())
3096}
3097
3098/// Delete a channel
3099async fn delete_channel(
3100    request: proto::DeleteChannel,
3101    response: Response<proto::DeleteChannel>,
3102    session: Session,
3103) -> Result<()> {
3104    let db = session.db().await;
3105
3106    let channel_id = request.channel_id;
3107    let (root_channel, removed_channels) = db
3108        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3109        .await?;
3110    response.send(proto::Ack {})?;
3111
3112    // Notify members of removed channels
3113    let mut update = proto::UpdateChannels::default();
3114    update
3115        .delete_channels
3116        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3117
3118    let connection_pool = session.connection_pool().await;
3119    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3120        session.peer.send(connection_id, update.clone())?;
3121    }
3122
3123    Ok(())
3124}
3125
3126/// Invite someone to join a channel.
3127async fn invite_channel_member(
3128    request: proto::InviteChannelMember,
3129    response: Response<proto::InviteChannelMember>,
3130    session: Session,
3131) -> Result<()> {
3132    let db = session.db().await;
3133    let channel_id = ChannelId::from_proto(request.channel_id);
3134    let invitee_id = UserId::from_proto(request.user_id);
3135    let InviteMemberResult {
3136        channel,
3137        notifications,
3138    } = db
3139        .invite_channel_member(
3140            channel_id,
3141            invitee_id,
3142            session.user_id(),
3143            request.role().into(),
3144        )
3145        .await?;
3146
3147    let update = proto::UpdateChannels {
3148        channel_invitations: vec![channel.to_proto()],
3149        ..Default::default()
3150    };
3151
3152    let connection_pool = session.connection_pool().await;
3153    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3154        session.peer.send(connection_id, update.clone())?;
3155    }
3156
3157    send_notifications(&connection_pool, &session.peer, notifications);
3158
3159    response.send(proto::Ack {})?;
3160    Ok(())
3161}
3162
3163/// remove someone from a channel
3164async fn remove_channel_member(
3165    request: proto::RemoveChannelMember,
3166    response: Response<proto::RemoveChannelMember>,
3167    session: Session,
3168) -> Result<()> {
3169    let db = session.db().await;
3170    let channel_id = ChannelId::from_proto(request.channel_id);
3171    let member_id = UserId::from_proto(request.user_id);
3172
3173    let RemoveChannelMemberResult {
3174        membership_update,
3175        notification_id,
3176    } = db
3177        .remove_channel_member(channel_id, member_id, session.user_id())
3178        .await?;
3179
3180    let mut connection_pool = session.connection_pool().await;
3181    notify_membership_updated(
3182        &mut connection_pool,
3183        membership_update,
3184        member_id,
3185        &session.peer,
3186    );
3187    for connection_id in connection_pool.user_connection_ids(member_id) {
3188        if let Some(notification_id) = notification_id {
3189            session
3190                .peer
3191                .send(
3192                    connection_id,
3193                    proto::DeleteNotification {
3194                        notification_id: notification_id.to_proto(),
3195                    },
3196                )
3197                .trace_err();
3198        }
3199    }
3200
3201    response.send(proto::Ack {})?;
3202    Ok(())
3203}
3204
/// Toggle the channel between public and private.
///
/// Care is taken to maintain the invariant that public channels only descend
/// from public channels (though members-only channels can appear at any point
/// in the hierarchy). NOTE(review): that invariant is presumably enforced
/// inside `db.set_channel_visibility`; nothing in this handler checks it.
///
/// After the change, every user subscribed under the root channel is either
/// subscribed to the channel and sent its data (if their role can still see
/// it) or unsubscribed and told to delete it.
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect first: the loop body mutates `connection_pool`, which would
    // conflict with the borrow held by the `channel_user_ids` iterator.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3251
3252/// Alter the role for a user in the channel.
3253async fn set_channel_member_role(
3254    request: proto::SetChannelMemberRole,
3255    response: Response<proto::SetChannelMemberRole>,
3256    session: Session,
3257) -> Result<()> {
3258    let db = session.db().await;
3259    let channel_id = ChannelId::from_proto(request.channel_id);
3260    let member_id = UserId::from_proto(request.user_id);
3261    let result = db
3262        .set_channel_member_role(
3263            channel_id,
3264            session.user_id(),
3265            member_id,
3266            request.role().into(),
3267        )
3268        .await?;
3269
3270    match result {
3271        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3272            let mut connection_pool = session.connection_pool().await;
3273            notify_membership_updated(
3274                &mut connection_pool,
3275                membership_update,
3276                member_id,
3277                &session.peer,
3278            )
3279        }
3280        db::SetMemberRoleResult::InviteUpdated(channel) => {
3281            let update = proto::UpdateChannels {
3282                channel_invitations: vec![channel.to_proto()],
3283                ..Default::default()
3284            };
3285
3286            for connection_id in session
3287                .connection_pool()
3288                .await
3289                .user_connection_ids(member_id)
3290            {
3291                session.peer.send(connection_id, update.clone())?;
3292            }
3293        }
3294    }
3295
3296    response.send(proto::Ack {})?;
3297    Ok(())
3298}
3299
3300/// Change the name of a channel
3301async fn rename_channel(
3302    request: proto::RenameChannel,
3303    response: Response<proto::RenameChannel>,
3304    session: Session,
3305) -> Result<()> {
3306    let db = session.db().await;
3307    let channel_id = ChannelId::from_proto(request.channel_id);
3308    let channel_model = db
3309        .rename_channel(channel_id, session.user_id(), &request.name)
3310        .await?;
3311    let root_id = channel_model.root_id();
3312    let channel = Channel::from_model(channel_model);
3313
3314    response.send(proto::RenameChannelResponse {
3315        channel: Some(channel.to_proto()),
3316    })?;
3317
3318    let connection_pool = session.connection_pool().await;
3319    let update = proto::UpdateChannels {
3320        channels: vec![channel.to_proto()],
3321        ..Default::default()
3322    };
3323    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3324        if role.can_see_channel(channel.visibility) {
3325            session.peer.send(connection_id, update.clone())?;
3326        }
3327    }
3328
3329    Ok(())
3330}
3331
3332/// Move a channel to a new parent.
3333async fn move_channel(
3334    request: proto::MoveChannel,
3335    response: Response<proto::MoveChannel>,
3336    session: Session,
3337) -> Result<()> {
3338    let channel_id = ChannelId::from_proto(request.channel_id);
3339    let to = ChannelId::from_proto(request.to);
3340
3341    let (root_id, channels) = session
3342        .db()
3343        .await
3344        .move_channel(channel_id, to, session.user_id())
3345        .await?;
3346
3347    let connection_pool = session.connection_pool().await;
3348    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3349        let channels = channels
3350            .iter()
3351            .filter_map(|channel| {
3352                if role.can_see_channel(channel.visibility) {
3353                    Some(channel.to_proto())
3354                } else {
3355                    None
3356                }
3357            })
3358            .collect::<Vec<_>>();
3359        if channels.is_empty() {
3360            continue;
3361        }
3362
3363        let update = proto::UpdateChannels {
3364            channels,
3365            ..Default::default()
3366        };
3367
3368        session.peer.send(connection_id, update.clone())?;
3369    }
3370
3371    response.send(Ack {})?;
3372    Ok(())
3373}
3374
3375async fn reorder_channel(
3376    request: proto::ReorderChannel,
3377    response: Response<proto::ReorderChannel>,
3378    session: Session,
3379) -> Result<()> {
3380    let channel_id = ChannelId::from_proto(request.channel_id);
3381    let direction = request.direction();
3382
3383    let updated_channels = session
3384        .db()
3385        .await
3386        .reorder_channel(channel_id, direction, session.user_id())
3387        .await?;
3388
3389    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3390        let connection_pool = session.connection_pool().await;
3391        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3392            let channels = updated_channels
3393                .iter()
3394                .filter_map(|channel| {
3395                    if role.can_see_channel(channel.visibility) {
3396                        Some(channel.to_proto())
3397                    } else {
3398                        None
3399                    }
3400                })
3401                .collect::<Vec<_>>();
3402
3403            if channels.is_empty() {
3404                continue;
3405            }
3406
3407            let update = proto::UpdateChannels {
3408                channels,
3409                ..Default::default()
3410            };
3411
3412            session.peer.send(connection_id, update.clone())?;
3413        }
3414    }
3415
3416    response.send(Ack {})?;
3417    Ok(())
3418}
3419
3420/// Get the list of channel members
3421async fn get_channel_members(
3422    request: proto::GetChannelMembers,
3423    response: Response<proto::GetChannelMembers>,
3424    session: Session,
3425) -> Result<()> {
3426    let db = session.db().await;
3427    let channel_id = ChannelId::from_proto(request.channel_id);
3428    let limit = if request.limit == 0 {
3429        u16::MAX as u64
3430    } else {
3431        request.limit
3432    };
3433    let (members, users) = db
3434        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3435        .await?;
3436    response.send(proto::GetChannelMembersResponse { members, users })?;
3437    Ok(())
3438}
3439
3440/// Accept or decline a channel invitation.
3441async fn respond_to_channel_invite(
3442    request: proto::RespondToChannelInvite,
3443    response: Response<proto::RespondToChannelInvite>,
3444    session: Session,
3445) -> Result<()> {
3446    let db = session.db().await;
3447    let channel_id = ChannelId::from_proto(request.channel_id);
3448    let RespondToChannelInvite {
3449        membership_update,
3450        notifications,
3451    } = db
3452        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3453        .await?;
3454
3455    let mut connection_pool = session.connection_pool().await;
3456    if let Some(membership_update) = membership_update {
3457        notify_membership_updated(
3458            &mut connection_pool,
3459            membership_update,
3460            session.user_id(),
3461            &session.peer,
3462        );
3463    } else {
3464        let update = proto::UpdateChannels {
3465            remove_channel_invitations: vec![channel_id.to_proto()],
3466            ..Default::default()
3467        };
3468
3469        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3470            session.peer.send(connection_id, update.clone())?;
3471        }
3472    };
3473
3474    send_notifications(&connection_pool, &session.peer, notifications);
3475
3476    response.send(proto::Ack {})?;
3477
3478    Ok(())
3479}
3480
3481/// Join the channels' room
3482async fn join_channel(
3483    request: proto::JoinChannel,
3484    response: Response<proto::JoinChannel>,
3485    session: Session,
3486) -> Result<()> {
3487    let channel_id = ChannelId::from_proto(request.channel_id);
3488    join_channel_internal(channel_id, Box::new(response), session).await
3489}
3490
/// A response handle that can answer a channel join with a
/// `proto::JoinRoomResponse`.
///
/// Lets `join_channel_internal` reply through either a `JoinChannel` or a
/// `JoinRoom` response handle.
trait JoinChannelInternalResponse {
    /// Send `result` back to the requester, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified so it is clear this calls the inherent
        // `Response::send`, not this trait method recursively.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified so it is clear this calls the inherent
        // `Response::send`, not this trait method recursively.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3504
/// Shared implementation for joining a channel's room.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first. The database guard is dropped while doing so and then
///    reacquired.
/// 2. Join the channel's room in the database. This yields the joined room, an
///    optional membership change, and the caller's role in the channel.
/// 3. If a LiveKit client is configured, build connection info:
///    - guests get a `guest_token` with `can_publish: false`;
///    - everyone else gets a `room_token` with `can_publish: true`.
///
///    If token creation fails, `trace_err()` yields `None`, so the join proceeds
///    without LiveKit info rather than failing.
/// 4. Reply to the caller, broadcast any membership change, and send the room
///    update.
/// 5. Once the block-scoped database and connection-pool guards are dropped,
///    broadcast the channel update (erroring if the join returned no channel,
///    even though the caller has already been answered) and refresh the
///    caller's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room
            // (presumably `leave_room_for_session` acquires it itself), then
            // reacquire it for the join below.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or when minting a token
        // fails (the `?` on `trace_err()` short-circuits the closure).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests get a non-publishing guest token; other roles may publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Answer the caller before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // `db` and `connection_pool` guards are dropped at the end of this block.
        joined_room
    };

    // Reacquires the connection pool only for the duration of this call.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3598
3599/// Start editing the channel notes
3600async fn join_channel_buffer(
3601    request: proto::JoinChannelBuffer,
3602    response: Response<proto::JoinChannelBuffer>,
3603    session: Session,
3604) -> Result<()> {
3605    let db = session.db().await;
3606    let channel_id = ChannelId::from_proto(request.channel_id);
3607
3608    let open_response = db
3609        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3610        .await?;
3611
3612    let collaborators = open_response.collaborators.clone();
3613    response.send(open_response)?;
3614
3615    let update = UpdateChannelBufferCollaborators {
3616        channel_id: channel_id.to_proto(),
3617        collaborators: collaborators.clone(),
3618    };
3619    channel_buffer_updated(
3620        session.connection_id,
3621        collaborators
3622            .iter()
3623            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3624        &update,
3625        &session.peer,
3626    );
3627
3628    Ok(())
3629}
3630
3631/// Edit the channel notes
3632async fn update_channel_buffer(
3633    request: proto::UpdateChannelBuffer,
3634    session: Session,
3635) -> Result<()> {
3636    let db = session.db().await;
3637    let channel_id = ChannelId::from_proto(request.channel_id);
3638
3639    let (collaborators, epoch, version) = db
3640        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3641        .await?;
3642
3643    channel_buffer_updated(
3644        session.connection_id,
3645        collaborators.clone(),
3646        &proto::UpdateChannelBuffer {
3647            channel_id: channel_id.to_proto(),
3648            operations: request.operations,
3649        },
3650        &session.peer,
3651    );
3652
3653    let pool = &*session.connection_pool().await;
3654
3655    let non_collaborators =
3656        pool.channel_connection_ids(channel_id)
3657            .filter_map(|(connection_id, _)| {
3658                if collaborators.contains(&connection_id) {
3659                    None
3660                } else {
3661                    Some(connection_id)
3662                }
3663            });
3664
3665    broadcast(None, non_collaborators, |peer_id| {
3666        session.peer.send(
3667            peer_id,
3668            proto::UpdateChannels {
3669                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3670                    channel_id: channel_id.to_proto(),
3671                    epoch: epoch as u64,
3672                    version: version.clone(),
3673                }],
3674                ..Default::default()
3675            },
3676        )
3677    });
3678
3679    Ok(())
3680}
3681
3682/// Rejoin the channel notes after a connection blip
3683async fn rejoin_channel_buffers(
3684    request: proto::RejoinChannelBuffers,
3685    response: Response<proto::RejoinChannelBuffers>,
3686    session: Session,
3687) -> Result<()> {
3688    let db = session.db().await;
3689    let buffers = db
3690        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3691        .await?;
3692
3693    for rejoined_buffer in &buffers {
3694        let collaborators_to_notify = rejoined_buffer
3695            .buffer
3696            .collaborators
3697            .iter()
3698            .filter_map(|c| Some(c.peer_id?.into()));
3699        channel_buffer_updated(
3700            session.connection_id,
3701            collaborators_to_notify,
3702            &proto::UpdateChannelBufferCollaborators {
3703                channel_id: rejoined_buffer.buffer.channel_id,
3704                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3705            },
3706            &session.peer,
3707        );
3708    }
3709
3710    response.send(proto::RejoinChannelBuffersResponse {
3711        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3712    })?;
3713
3714    Ok(())
3715}
3716
3717/// Stop editing the channel notes
3718async fn leave_channel_buffer(
3719    request: proto::LeaveChannelBuffer,
3720    response: Response<proto::LeaveChannelBuffer>,
3721    session: Session,
3722) -> Result<()> {
3723    let db = session.db().await;
3724    let channel_id = ChannelId::from_proto(request.channel_id);
3725
3726    let left_buffer = db
3727        .leave_channel_buffer(channel_id, session.connection_id)
3728        .await?;
3729
3730    response.send(Ack {})?;
3731
3732    channel_buffer_updated(
3733        session.connection_id,
3734        left_buffer.connections,
3735        &proto::UpdateChannelBufferCollaborators {
3736            channel_id: channel_id.to_proto(),
3737            collaborators: left_buffer.collaborators,
3738        },
3739        &session.peer,
3740    );
3741
3742    Ok(())
3743}
3744
3745fn channel_buffer_updated<T: EnvelopedMessage>(
3746    sender_id: ConnectionId,
3747    collaborators: impl IntoIterator<Item = ConnectionId>,
3748    message: &T,
3749    peer: &Peer,
3750) {
3751    broadcast(Some(sender_id), collaborators, |peer_id| {
3752        peer.send(peer_id, message.clone())
3753    });
3754}
3755
3756fn send_notifications(
3757    connection_pool: &ConnectionPool,
3758    peer: &Peer,
3759    notifications: db::NotificationBatch,
3760) {
3761    for (user_id, notification) in notifications {
3762        for connection_id in connection_pool.user_connection_ids(user_id) {
3763            if let Err(error) = peer.send(
3764                connection_id,
3765                proto::AddNotification {
3766                    notification: Some(notification.clone()),
3767                },
3768            ) {
3769                tracing::error!(
3770                    "failed to send notification to {:?} {}",
3771                    connection_id,
3772                    error
3773                );
3774            }
3775        }
3776    }
3777}
3778
3779/// Send a message to the channel
3780async fn send_channel_message(
3781    request: proto::SendChannelMessage,
3782    response: Response<proto::SendChannelMessage>,
3783    session: Session,
3784) -> Result<()> {
3785    // Validate the message body.
3786    let body = request.body.trim().to_string();
3787    if body.len() > MAX_MESSAGE_LEN {
3788        return Err(anyhow!("message is too long"))?;
3789    }
3790    if body.is_empty() {
3791        return Err(anyhow!("message can't be blank"))?;
3792    }
3793
3794    // TODO: adjust mentions if body is trimmed
3795
3796    let timestamp = OffsetDateTime::now_utc();
3797    let nonce = request.nonce.context("nonce can't be blank")?;
3798
3799    let channel_id = ChannelId::from_proto(request.channel_id);
3800    let CreatedChannelMessage {
3801        message_id,
3802        participant_connection_ids,
3803        notifications,
3804    } = session
3805        .db()
3806        .await
3807        .create_channel_message(
3808            channel_id,
3809            session.user_id(),
3810            &body,
3811            &request.mentions,
3812            timestamp,
3813            nonce.clone().into(),
3814            request.reply_to_message_id.map(MessageId::from_proto),
3815        )
3816        .await?;
3817
3818    let message = proto::ChannelMessage {
3819        sender_id: session.user_id().to_proto(),
3820        id: message_id.to_proto(),
3821        body,
3822        mentions: request.mentions,
3823        timestamp: timestamp.unix_timestamp() as u64,
3824        nonce: Some(nonce),
3825        reply_to_message_id: request.reply_to_message_id,
3826        edited_at: None,
3827    };
3828    broadcast(
3829        Some(session.connection_id),
3830        participant_connection_ids.clone(),
3831        |connection| {
3832            session.peer.send(
3833                connection,
3834                proto::ChannelMessageSent {
3835                    channel_id: channel_id.to_proto(),
3836                    message: Some(message.clone()),
3837                },
3838            )
3839        },
3840    );
3841    response.send(proto::SendChannelMessageResponse {
3842        message: Some(message),
3843    })?;
3844
3845    let pool = &*session.connection_pool().await;
3846    let non_participants =
3847        pool.channel_connection_ids(channel_id)
3848            .filter_map(|(connection_id, _)| {
3849                if participant_connection_ids.contains(&connection_id) {
3850                    None
3851                } else {
3852                    Some(connection_id)
3853                }
3854            });
3855    broadcast(None, non_participants, |peer_id| {
3856        session.peer.send(
3857            peer_id,
3858            proto::UpdateChannels {
3859                latest_channel_message_ids: vec![proto::ChannelMessageId {
3860                    channel_id: channel_id.to_proto(),
3861                    message_id: message_id.to_proto(),
3862                }],
3863                ..Default::default()
3864            },
3865        )
3866    });
3867    send_notifications(pool, &session.peer, notifications);
3868
3869    Ok(())
3870}
3871
3872/// Delete a channel message
3873async fn remove_channel_message(
3874    request: proto::RemoveChannelMessage,
3875    response: Response<proto::RemoveChannelMessage>,
3876    session: Session,
3877) -> Result<()> {
3878    let channel_id = ChannelId::from_proto(request.channel_id);
3879    let message_id = MessageId::from_proto(request.message_id);
3880    let (connection_ids, existing_notification_ids) = session
3881        .db()
3882        .await
3883        .remove_channel_message(channel_id, message_id, session.user_id())
3884        .await?;
3885
3886    broadcast(
3887        Some(session.connection_id),
3888        connection_ids,
3889        move |connection| {
3890            session.peer.send(connection, request.clone())?;
3891
3892            for notification_id in &existing_notification_ids {
3893                session.peer.send(
3894                    connection,
3895                    proto::DeleteNotification {
3896                        notification_id: (*notification_id).to_proto(),
3897                    },
3898                )?;
3899            }
3900
3901            Ok(())
3902        },
3903    );
3904    response.send(proto::Ack {})?;
3905    Ok(())
3906}
3907
3908async fn update_channel_message(
3909    request: proto::UpdateChannelMessage,
3910    response: Response<proto::UpdateChannelMessage>,
3911    session: Session,
3912) -> Result<()> {
3913    let channel_id = ChannelId::from_proto(request.channel_id);
3914    let message_id = MessageId::from_proto(request.message_id);
3915    let updated_at = OffsetDateTime::now_utc();
3916    let UpdatedChannelMessage {
3917        message_id,
3918        participant_connection_ids,
3919        notifications,
3920        reply_to_message_id,
3921        timestamp,
3922        deleted_mention_notification_ids,
3923        updated_mention_notifications,
3924    } = session
3925        .db()
3926        .await
3927        .update_channel_message(
3928            channel_id,
3929            message_id,
3930            session.user_id(),
3931            request.body.as_str(),
3932            &request.mentions,
3933            updated_at,
3934        )
3935        .await?;
3936
3937    let nonce = request.nonce.clone().context("nonce can't be blank")?;
3938
3939    let message = proto::ChannelMessage {
3940        sender_id: session.user_id().to_proto(),
3941        id: message_id.to_proto(),
3942        body: request.body.clone(),
3943        mentions: request.mentions.clone(),
3944        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3945        nonce: Some(nonce),
3946        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3947        edited_at: Some(updated_at.unix_timestamp() as u64),
3948    };
3949
3950    response.send(proto::Ack {})?;
3951
3952    let pool = &*session.connection_pool().await;
3953    broadcast(
3954        Some(session.connection_id),
3955        participant_connection_ids,
3956        |connection| {
3957            session.peer.send(
3958                connection,
3959                proto::ChannelMessageUpdate {
3960                    channel_id: channel_id.to_proto(),
3961                    message: Some(message.clone()),
3962                },
3963            )?;
3964
3965            for notification_id in &deleted_mention_notification_ids {
3966                session.peer.send(
3967                    connection,
3968                    proto::DeleteNotification {
3969                        notification_id: (*notification_id).to_proto(),
3970                    },
3971                )?;
3972            }
3973
3974            for notification in &updated_mention_notifications {
3975                session.peer.send(
3976                    connection,
3977                    proto::UpdateNotification {
3978                        notification: Some(notification.clone()),
3979                    },
3980                )?;
3981            }
3982
3983            Ok(())
3984        },
3985    );
3986
3987    send_notifications(pool, &session.peer, notifications);
3988
3989    Ok(())
3990}
3991
3992/// Mark a channel message as read
3993async fn acknowledge_channel_message(
3994    request: proto::AckChannelMessage,
3995    session: Session,
3996) -> Result<()> {
3997    let channel_id = ChannelId::from_proto(request.channel_id);
3998    let message_id = MessageId::from_proto(request.message_id);
3999    let notifications = session
4000        .db()
4001        .await
4002        .observe_channel_message(channel_id, session.user_id(), message_id)
4003        .await?;
4004    send_notifications(
4005        &*session.connection_pool().await,
4006        &session.peer,
4007        notifications,
4008    );
4009    Ok(())
4010}
4011
4012/// Mark a buffer version as synced
4013async fn acknowledge_buffer_version(
4014    request: proto::AckBufferOperation,
4015    session: Session,
4016) -> Result<()> {
4017    let buffer_id = BufferId::from_proto(request.buffer_id);
4018    session
4019        .db()
4020        .await
4021        .observe_buffer_version(
4022            buffer_id,
4023            session.user_id(),
4024            request.epoch as i32,
4025            &request.version,
4026        )
4027        .await?;
4028    Ok(())
4029}
4030
4031/// Get a Supermaven API key for the user
4032async fn get_supermaven_api_key(
4033    _request: proto::GetSupermavenApiKey,
4034    response: Response<proto::GetSupermavenApiKey>,
4035    session: Session,
4036) -> Result<()> {
4037    let user_id: String = session.user_id().to_string();
4038    if !session.is_staff() {
4039        return Err(anyhow!("supermaven not enabled for this account"))?;
4040    }
4041
4042    let email = session.email().context("user must have an email")?;
4043
4044    let supermaven_admin_api = session
4045        .supermaven_client
4046        .as_ref()
4047        .context("supermaven not configured")?;
4048
4049    let result = supermaven_admin_api
4050        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4051        .await?;
4052
4053    response.send(proto::GetSupermavenApiKeyResponse {
4054        api_key: result.api_key,
4055    })?;
4056
4057    Ok(())
4058}
4059
4060/// Start receiving chat updates for a channel
4061async fn join_channel_chat(
4062    request: proto::JoinChannelChat,
4063    response: Response<proto::JoinChannelChat>,
4064    session: Session,
4065) -> Result<()> {
4066    let channel_id = ChannelId::from_proto(request.channel_id);
4067
4068    let db = session.db().await;
4069    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4070        .await?;
4071    let messages = db
4072        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4073        .await?;
4074    response.send(proto::JoinChannelChatResponse {
4075        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4076        messages,
4077    })?;
4078    Ok(())
4079}
4080
4081/// Stop receiving chat updates for a channel
4082async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
4083    let channel_id = ChannelId::from_proto(request.channel_id);
4084    session
4085        .db()
4086        .await
4087        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4088        .await?;
4089    Ok(())
4090}
4091
4092/// Retrieve the chat history for a channel
4093async fn get_channel_messages(
4094    request: proto::GetChannelMessages,
4095    response: Response<proto::GetChannelMessages>,
4096    session: Session,
4097) -> Result<()> {
4098    let channel_id = ChannelId::from_proto(request.channel_id);
4099    let messages = session
4100        .db()
4101        .await
4102        .get_channel_messages(
4103            channel_id,
4104            session.user_id(),
4105            MESSAGE_COUNT_PER_PAGE,
4106            Some(MessageId::from_proto(request.before_message_id)),
4107        )
4108        .await?;
4109    response.send(proto::GetChannelMessagesResponse {
4110        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4111        messages,
4112    })?;
4113    Ok(())
4114}
4115
4116/// Retrieve specific chat messages
4117async fn get_channel_messages_by_id(
4118    request: proto::GetChannelMessagesById,
4119    response: Response<proto::GetChannelMessagesById>,
4120    session: Session,
4121) -> Result<()> {
4122    let message_ids = request
4123        .message_ids
4124        .iter()
4125        .map(|id| MessageId::from_proto(*id))
4126        .collect::<Vec<_>>();
4127    let messages = session
4128        .db()
4129        .await
4130        .get_channel_messages_by_id(session.user_id(), &message_ids)
4131        .await?;
4132    response.send(proto::GetChannelMessagesResponse {
4133        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4134        messages,
4135    })?;
4136    Ok(())
4137}
4138
4139/// Retrieve the current users notifications
4140async fn get_notifications(
4141    request: proto::GetNotifications,
4142    response: Response<proto::GetNotifications>,
4143    session: Session,
4144) -> Result<()> {
4145    let notifications = session
4146        .db()
4147        .await
4148        .get_notifications(
4149            session.user_id(),
4150            NOTIFICATION_COUNT_PER_PAGE,
4151            request.before_id.map(db::NotificationId::from_proto),
4152        )
4153        .await?;
4154    response.send(proto::GetNotificationsResponse {
4155        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4156        notifications,
4157    })?;
4158    Ok(())
4159}
4160
4161/// Mark notifications as read
4162async fn mark_notification_as_read(
4163    request: proto::MarkNotificationRead,
4164    response: Response<proto::MarkNotificationRead>,
4165    session: Session,
4166) -> Result<()> {
4167    let database = &session.db().await;
4168    let notifications = database
4169        .mark_notification_as_read_by_id(
4170            session.user_id(),
4171            NotificationId::from_proto(request.notification_id),
4172        )
4173        .await?;
4174    send_notifications(
4175        &*session.connection_pool().await,
4176        &session.peer,
4177        notifications,
4178    );
4179    response.send(proto::Ack {})?;
4180    Ok(())
4181}
4182
4183/// Get the current users information
4184async fn get_private_user_info(
4185    _request: proto::GetPrivateUserInfo,
4186    response: Response<proto::GetPrivateUserInfo>,
4187    session: Session,
4188) -> Result<()> {
4189    let db = session.db().await;
4190
4191    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4192    let user = db
4193        .get_user_by_id(session.user_id())
4194        .await?
4195        .context("user not found")?;
4196    let flags = db.get_user_flags(session.user_id()).await?;
4197
4198    response.send(proto::GetPrivateUserInfoResponse {
4199        metrics_id,
4200        staff: user.admin,
4201        flags,
4202        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4203    })?;
4204    Ok(())
4205}
4206
4207/// Accept the terms of service (tos) on behalf of the current user
4208async fn accept_terms_of_service(
4209    _request: proto::AcceptTermsOfService,
4210    response: Response<proto::AcceptTermsOfService>,
4211    session: Session,
4212) -> Result<()> {
4213    let db = session.db().await;
4214
4215    let accepted_tos_at = Utc::now();
4216    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4217        .await?;
4218
4219    response.send(proto::AcceptTermsOfServiceResponse {
4220        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4221    })?;
4222
4223    // When the user accepts the terms of service, we want to refresh their LLM
4224    // token to grant access.
4225    session
4226        .peer
4227        .send(session.connection_id, proto::RefreshLlmToken {})?;
4228
4229    Ok(())
4230}
4231
4232async fn get_llm_api_token(
4233    _request: proto::GetLlmToken,
4234    response: Response<proto::GetLlmToken>,
4235    session: Session,
4236) -> Result<()> {
4237    let db = session.db().await;
4238
4239    let flags = db.get_user_flags(session.user_id()).await?;
4240
4241    let user_id = session.user_id();
4242    let user = db
4243        .get_user_by_id(user_id)
4244        .await?
4245        .with_context(|| format!("user {user_id} not found"))?;
4246
4247    if user.accepted_tos_at.is_none() {
4248        Err(anyhow!("terms of service not accepted"))?
4249    }
4250
4251    let stripe_client = session
4252        .app_state
4253        .stripe_client
4254        .as_ref()
4255        .context("failed to retrieve Stripe client")?;
4256
4257    let stripe_billing = session
4258        .app_state
4259        .stripe_billing
4260        .as_ref()
4261        .context("failed to retrieve Stripe billing object")?;
4262
4263    let billing_customer = if let Some(billing_customer) =
4264        db.get_billing_customer_by_user_id(user.id).await?
4265    {
4266        billing_customer
4267    } else {
4268        let customer_id = stripe_billing
4269            .find_or_create_customer_by_email(user.email_address.as_deref())
4270            .await?;
4271
4272        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
4273            .await?
4274            .context("billing customer not found")?
4275    };
4276
4277    let billing_subscription =
4278        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
4279            billing_subscription
4280        } else {
4281            let stripe_customer_id =
4282                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
4283
4284            let stripe_subscription = stripe_billing
4285                .subscribe_to_zed_free(stripe_customer_id)
4286                .await?;
4287
4288            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
4289                billing_customer_id: billing_customer.id,
4290                kind: Some(SubscriptionKind::ZedFree),
4291                stripe_subscription_id: stripe_subscription.id.to_string(),
4292                stripe_subscription_status: stripe_subscription.status.into(),
4293                stripe_cancellation_reason: None,
4294                stripe_current_period_start: Some(stripe_subscription.current_period_start),
4295                stripe_current_period_end: Some(stripe_subscription.current_period_end),
4296            })
4297            .await?
4298        };
4299
4300    let billing_preferences = db.get_billing_preferences(user.id).await?;
4301
4302    let token = LlmTokenClaims::create(
4303        &user,
4304        session.is_staff(),
4305        billing_customer,
4306        billing_preferences,
4307        &flags,
4308        billing_subscription,
4309        session.system_id.clone(),
4310        &session.app_state.config,
4311    )?;
4312    response.send(proto::GetLlmTokenResponse { token })?;
4313    Ok(())
4314}
4315
4316fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4317    let message = match message {
4318        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4319        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4320        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4321        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4322        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4323            code: frame.code.into(),
4324            reason: frame.reason.as_str().to_owned().into(),
4325        })),
4326        // We should never receive a frame while reading the message, according
4327        // to the `tungstenite` maintainers:
4328        //
4329        // > It cannot occur when you read messages from the WebSocket, but it
4330        // > can be used when you want to send the raw frames (e.g. you want to
4331        // > send the frames to the WebSocket without composing the full message first).
4332        // >
4333        // > — https://github.com/snapview/tungstenite-rs/issues/268
4334        TungsteniteMessage::Frame(_) => {
4335            bail!("received an unexpected frame while reading the message")
4336        }
4337    };
4338
4339    Ok(message)
4340}
4341
4342fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4343    match message {
4344        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4345        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4346        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4347        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4348        AxumMessage::Close(frame) => {
4349            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4350                code: frame.code.into(),
4351                reason: frame.reason.as_ref().into(),
4352            }))
4353        }
4354    }
4355}
4356
4357fn notify_membership_updated(
4358    connection_pool: &mut ConnectionPool,
4359    result: MembershipUpdated,
4360    user_id: UserId,
4361    peer: &Peer,
4362) {
4363    for membership in &result.new_channels.channel_memberships {
4364        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4365    }
4366    for channel_id in &result.removed_channels {
4367        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4368    }
4369
4370    let user_channels_update = proto::UpdateUserChannels {
4371        channel_memberships: result
4372            .new_channels
4373            .channel_memberships
4374            .iter()
4375            .map(|cm| proto::ChannelMembership {
4376                channel_id: cm.channel_id.to_proto(),
4377                role: cm.role.into(),
4378            })
4379            .collect(),
4380        ..Default::default()
4381    };
4382
4383    let mut update = build_channels_update(result.new_channels);
4384    update.delete_channels = result
4385        .removed_channels
4386        .into_iter()
4387        .map(|id| id.to_proto())
4388        .collect();
4389    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4390
4391    for connection_id in connection_pool.user_connection_ids(user_id) {
4392        peer.send(connection_id, user_channels_update.clone())
4393            .trace_err();
4394        peer.send(connection_id, update.clone()).trace_err();
4395    }
4396}
4397
4398fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4399    proto::UpdateUserChannels {
4400        channel_memberships: channels
4401            .channel_memberships
4402            .iter()
4403            .map(|m| proto::ChannelMembership {
4404                channel_id: m.channel_id.to_proto(),
4405                role: m.role.into(),
4406            })
4407            .collect(),
4408        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4409        observed_channel_message_id: channels.observed_channel_messages.clone(),
4410    }
4411}
4412
4413fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4414    let mut update = proto::UpdateChannels::default();
4415
4416    for channel in channels.channels {
4417        update.channels.push(channel.to_proto());
4418    }
4419
4420    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4421    update.latest_channel_message_ids = channels.latest_channel_messages;
4422
4423    for (channel_id, participants) in channels.channel_participants {
4424        update
4425            .channel_participants
4426            .push(proto::ChannelParticipants {
4427                channel_id: channel_id.to_proto(),
4428                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4429            });
4430    }
4431
4432    for channel in channels.invited_channels {
4433        update.channel_invitations.push(channel.to_proto());
4434    }
4435
4436    update
4437}
4438
4439fn build_initial_contacts_update(
4440    contacts: Vec<db::Contact>,
4441    pool: &ConnectionPool,
4442) -> proto::UpdateContacts {
4443    let mut update = proto::UpdateContacts::default();
4444
4445    for contact in contacts {
4446        match contact {
4447            db::Contact::Accepted { user_id, busy } => {
4448                update.contacts.push(contact_for_user(user_id, busy, pool));
4449            }
4450            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4451            db::Contact::Incoming { user_id } => {
4452                update
4453                    .incoming_requests
4454                    .push(proto::IncomingContactRequest {
4455                        requester_id: user_id.to_proto(),
4456                    })
4457            }
4458        }
4459    }
4460
4461    update
4462}
4463
4464fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4465    proto::Contact {
4466        user_id: user_id.to_proto(),
4467        online: pool.is_user_online(user_id),
4468        busy,
4469    }
4470}
4471
4472fn room_updated(room: &proto::Room, peer: &Peer) {
4473    broadcast(
4474        None,
4475        room.participants
4476            .iter()
4477            .filter_map(|participant| Some(participant.peer_id?.into())),
4478        |peer_id| {
4479            peer.send(
4480                peer_id,
4481                proto::RoomUpdated {
4482                    room: Some(room.clone()),
4483                },
4484            )
4485        },
4486    );
4487}
4488
4489fn channel_updated(
4490    channel: &db::channel::Model,
4491    room: &proto::Room,
4492    peer: &Peer,
4493    pool: &ConnectionPool,
4494) {
4495    let participants = room
4496        .participants
4497        .iter()
4498        .map(|p| p.user_id)
4499        .collect::<Vec<_>>();
4500
4501    broadcast(
4502        None,
4503        pool.channel_connection_ids(channel.root_id())
4504            .filter_map(|(channel_id, role)| {
4505                role.can_see_channel(channel.visibility)
4506                    .then_some(channel_id)
4507            }),
4508        |peer_id| {
4509            peer.send(
4510                peer_id,
4511                proto::UpdateChannels {
4512                    channel_participants: vec![proto::ChannelParticipants {
4513                        channel_id: channel.id.to_proto(),
4514                        participant_user_ids: participants.clone(),
4515                    }],
4516                    ..Default::default()
4517                },
4518            )
4519        },
4520    );
4521}
4522
4523async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4524    let db = session.db().await;
4525
4526    let contacts = db.get_contacts(user_id).await?;
4527    let busy = db.is_user_busy(user_id).await?;
4528
4529    let pool = session.connection_pool().await;
4530    let updated_contact = contact_for_user(user_id, busy, &pool);
4531    for contact in contacts {
4532        if let db::Contact::Accepted {
4533            user_id: contact_user_id,
4534            ..
4535        } = contact
4536        {
4537            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4538                session
4539                    .peer
4540                    .send(
4541                        contact_conn_id,
4542                        proto::UpdateContacts {
4543                            contacts: vec![updated_contact.clone()],
4544                            remove_contacts: Default::default(),
4545                            incoming_requests: Default::default(),
4546                            remove_incoming_requests: Default::default(),
4547                            outgoing_requests: Default::default(),
4548                            remove_outgoing_requests: Default::default(),
4549                        },
4550                    )
4551                    .trace_err();
4552            }
4553        }
4554    }
4555    Ok(())
4556}
4557
/// Removes `connection_id` from the room it is in (if any) and notifies everyone affected.
///
/// When the database reports that the connection left a room, this:
/// - notifies the connections of every project that was left (via `project_left`),
/// - broadcasts the updated room to its participants,
/// - if the room belongs to a channel, broadcasts the participant list to that
///   channel's visible subscribers,
/// - sends `CallCanceled` to every connection of users whose calls were canceled,
/// - refreshes the contact status of the session's user and of those users,
/// - when a LiveKit client is configured, removes the user from the LiveKit room
///   and deletes that room if the database reported it as deleted. LiveKit
///   failures are logged, not returned.
///
/// Returns `Ok(())` without doing anything if the connection was not in a room.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared without initializers so they can be assigned inside the `if let`
    // below and used after it. The `else` branch returns early, so every path
    // that reaches their uses has initialized all of them.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out below. As a result, the `room` value
        // broadcast by `room_updated` carries a defaulted `livekit_room` field.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection pool guard is released before
    // `update_user_contacts`, which acquires the pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // NOTE(review): an error from `update_user_contacts` returns early via `?`
    // and skips the LiveKit cleanup below. Confirm that this is intended.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4631
4632async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4633    let left_channel_buffers = session
4634        .db()
4635        .await
4636        .leave_channel_buffers(session.connection_id)
4637        .await?;
4638
4639    for left_buffer in left_channel_buffers {
4640        channel_buffer_updated(
4641            session.connection_id,
4642            left_buffer.connections,
4643            &proto::UpdateChannelBufferCollaborators {
4644                channel_id: left_buffer.channel_id.to_proto(),
4645                collaborators: left_buffer.collaborators,
4646            },
4647            &session.peer,
4648        );
4649    }
4650
4651    Ok(())
4652}
4653
4654fn project_left(project: &db::LeftProject, session: &Session) {
4655    for connection_id in &project.connection_ids {
4656        if project.should_unshare {
4657            session
4658                .peer
4659                .send(
4660                    *connection_id,
4661                    proto::UnshareProject {
4662                        project_id: project.id.to_proto(),
4663                    },
4664                )
4665                .trace_err();
4666        } else {
4667            session
4668                .peer
4669                .send(
4670                    *connection_id,
4671                    proto::RemoveProjectCollaborator {
4672                        project_id: project.id.to_proto(),
4673                        peer_id: Some(session.connection_id.into()),
4674                    },
4675                )
4676                .trace_err();
4677        }
4678    }
4679}
4680
4681pub trait ResultExt {
4682    type Ok;
4683
4684    fn trace_err(self) -> Option<Self::Ok>;
4685}
4686
4687impl<T, E> ResultExt for Result<T, E>
4688where
4689    E: std::fmt::Debug,
4690{
4691    type Ok = T;
4692
4693    #[track_caller]
4694    fn trace_err(self) -> Option<T> {
4695        match self {
4696            Ok(value) => Some(value),
4697            Err(error) => {
4698                tracing::error!("{:?}", error);
4699                None
4700            }
4701        }
4702    }
4703}