rpc.rs

   1mod connection_pool;
   2
   3use crate::api::billing::find_or_create_billing_customer;
   4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   5use crate::db::billing_subscription::SubscriptionKind;
   6use crate::llm::db::LlmDatabase;
   7use crate::llm::{
   8    AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
   9    MIN_ACCOUNT_AGE_FOR_LLM_USE,
  10};
  11use crate::stripe_client::StripeCustomerId;
  12use crate::{
  13    AppState, Error, Result, auth,
  14    db::{
  15        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
  16        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  17        NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
  18        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  19    },
  20    executor::Executor,
  21};
  22use anyhow::{Context as _, anyhow, bail};
  23use async_tungstenite::tungstenite::{
  24    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  25};
  26use axum::headers::UserAgent;
  27use axum::{
  28    Extension, Router, TypedHeader,
  29    body::Body,
  30    extract::{
  31        ConnectInfo, WebSocketUpgrade,
  32        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  33    },
  34    headers::{Header, HeaderName},
  35    http::StatusCode,
  36    middleware,
  37    response::IntoResponse,
  38    routing::get,
  39};
  40use chrono::Utc;
  41use collections::{HashMap, HashSet};
  42pub use connection_pool::{ConnectionPool, ZedVersion};
  43use core::fmt::{self, Debug, Formatter};
  44use reqwest_client::ReqwestClient;
  45use rpc::proto::{MultiLspQuery, split_repository_update};
  46use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  47
  48use futures::{
  49    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  50    stream::FuturesUnordered,
  51};
  52use prometheus::{IntGauge, register_int_gauge};
  53use rpc::{
  54    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  55    proto::{
  56        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  57        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  58    },
  59};
  60use semantic_version::SemanticVersion;
  61use serde::{Serialize, Serializer};
  62use std::{
  63    any::TypeId,
  64    future::Future,
  65    marker::PhantomData,
  66    mem,
  67    net::SocketAddr,
  68    ops::{Deref, DerefMut},
  69    rc::Rc,
  70    sync::{
  71        Arc, OnceLock,
  72        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  73    },
  74    time::{Duration, Instant},
  75};
  76use time::OffsetDateTime;
  77use tokio::sync::{Semaphore, watch};
  78use tower::ServiceBuilder;
  79use tracing::{
  80    Instrument,
  81    field::{self},
  82    info_span, instrument,
  83};
  84
/// Reconnect timeout (30s).
/// NOTE(review): presumably how long a disconnected client may take to
/// reconnect before its state is discarded — confirm at the call sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before `Server::start` begins cleaning up resources owned by stale
/// servers. Kubernetes gives terminated pods 10s to shut down gracefully; we
/// wait a little longer so those pods are gone before their old resources are
/// cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Presumably the page size for channel-message queries — confirm at call sites.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Presumably the maximum channel-message length — confirm the unit (bytes vs chars).
const MAX_MESSAGE_LEN: usize = 1024;
/// Presumably the page size for notification queries — confirm at call sites.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Upper bound on simultaneously held `ConnectionGuard`s.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Number of currently held `ConnectionGuard`s.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
  96
/// Type-erased message handler stored in `Server::handlers`.
///
/// It receives any typed envelope plus the sender's `Session` and returns a
/// boxed `'static` future that performs the handling. Concrete handlers are
/// wrapped into this shape by `Server::add_handler`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  99
 100pub struct ConnectionGuard;
 101
 102impl ConnectionGuard {
 103    pub fn try_acquire() -> Result<Self, ()> {
 104        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
 105        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 106            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 107            tracing::error!(
 108                "too many concurrent connections: {}",
 109                current_connections + 1
 110            );
 111            return Err(());
 112        }
 113        Ok(ConnectionGuard)
 114    }
 115}
 116
 117impl Drop for ConnectionGuard {
 118    fn drop(&mut self) {
 119        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 120    }
 121}
 122
 123struct Response<R> {
 124    peer: Arc<Peer>,
 125    receipt: Receipt<R>,
 126    responded: Arc<AtomicBool>,
 127}
 128
 129impl<R: RequestMessage> Response<R> {
 130    fn send(self, payload: R::Response) -> Result<()> {
 131        self.responded.store(true, SeqCst);
 132        self.peer.respond(self.receipt, payload)?;
 133        Ok(())
 134    }
 135}
 136
 137#[derive(Clone, Debug)]
 138pub enum Principal {
 139    User(User),
 140    Impersonated { user: User, admin: User },
 141}
 142
 143impl Principal {
 144    fn user(&self) -> &User {
 145        match self {
 146            Principal::User(user) => user,
 147            Principal::Impersonated { user, .. } => user,
 148        }
 149    }
 150
 151    fn update_span(&self, span: &tracing::Span) {
 152        match &self {
 153            Principal::User(user) => {
 154                span.record("user_id", user.id.0);
 155                span.record("login", &user.github_login);
 156            }
 157            Principal::Impersonated { user, admin } => {
 158                span.record("user_id", user.id.0);
 159                span.record("login", &user.github_login);
 160                span.record("impersonator", &admin.github_login);
 161            }
 162        }
 163    }
 164}
 165
/// Per-connection context passed to every message handler.
#[derive(Clone)]
struct Session {
    /// Who this connection acts as (a user, or a user impersonated by an admin).
    principal: Principal,
    /// The connection this session belongs to.
    connection_id: ConnectionId,
    /// Shared database handle; lock it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// The RPC peer used to send messages and responses.
    peer: Arc<Peer>,
    /// Shared connection pool; lock it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Supermaven admin API client, or `None` when no client is available.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    // `allow(unused)` silences the dead-code lint for this field.
    // NOTE(review): confirm whether it is read anywhere before removing it.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// Presumably the client-reported system id (cf. `SystemIdHeader`) — confirm.
    system_id: Option<String>,
    _executor: Executor,
}
 181
 182impl Session {
 183    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 184        #[cfg(test)]
 185        tokio::task::yield_now().await;
 186        let guard = self.db.lock().await;
 187        #[cfg(test)]
 188        tokio::task::yield_now().await;
 189        guard
 190    }
 191
 192    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 193        #[cfg(test)]
 194        tokio::task::yield_now().await;
 195        let guard = self.connection_pool.lock();
 196        ConnectionPoolGuard {
 197            guard,
 198            _not_send: PhantomData,
 199        }
 200    }
 201
 202    fn is_staff(&self) -> bool {
 203        match &self.principal {
 204            Principal::User(user) => user.admin,
 205            Principal::Impersonated { .. } => true,
 206        }
 207    }
 208
 209    fn user_id(&self) -> UserId {
 210        match &self.principal {
 211            Principal::User(user) => user.id,
 212            Principal::Impersonated { user, .. } => user.id,
 213        }
 214    }
 215
 216    pub fn email(&self) -> Option<String> {
 217        match &self.principal {
 218            Principal::User(user) => user.email_address.clone(),
 219            Principal::Impersonated { user, .. } => user.email_address.clone(),
 220        }
 221    }
 222}
 223
 224impl Debug for Session {
 225    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 226        let mut result = f.debug_struct("Session");
 227        match &self.principal {
 228            Principal::User(user) => {
 229                result.field("user", &user.github_login);
 230            }
 231            Principal::Impersonated { user, admin } => {
 232                result.field("user", &user.github_login);
 233                result.field("impersonator", &admin.github_login);
 234            }
 235        }
 236        result.field("connection_id", &self.connection_id).finish()
 237    }
 238}
 239
 240struct DbHandle(Arc<Database>);
 241
 242impl Deref for DbHandle {
 243    type Target = Database;
 244
 245    fn deref(&self) -> &Self::Target {
 246        self.0.as_ref()
 247    }
 248}
 249
/// The collab RPC server: owns the RPC peer, the connection pool, and the
/// table of message handlers keyed by message type.
pub struct Server {
    /// This server's id. It sits behind a mutex so the test-only `reset` can change it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// One handler per message type, keyed by the message's `TypeId`.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown signal: `teardown` sends `true`, the test-only `reset` sends `false`.
    teardown: watch::Sender<bool>,
}
 258
/// Lock guard for the shared [`ConnectionPool`].
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is `!Send`, so this marker makes the guard `!Send`. A future that
    // holds the guard across an `.await` therefore cannot be `Send`.
    _not_send: PhantomData<Rc<()>>,
}
 263
/// Serializable snapshot of the server's peer and its (locked) connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized through `serialize_deref`, i.e. as the value the guard
    /// dereferences to rather than as the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 270
 271pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 272where
 273    S: Serializer,
 274    T: Deref<Target = U>,
 275    U: Serialize,
 276{
 277    Serialize::serialize(value.deref(), serializer)
 278}
 279
 280impl Server {
 281    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 282        let mut server = Self {
 283            id: parking_lot::Mutex::new(id),
 284            peer: Peer::new(id.0 as u32),
 285            app_state: app_state.clone(),
 286            connection_pool: Default::default(),
 287            handlers: Default::default(),
 288            teardown: watch::channel(false).0,
 289        };
 290
 291        server
 292            .add_request_handler(ping)
 293            .add_request_handler(create_room)
 294            .add_request_handler(join_room)
 295            .add_request_handler(rejoin_room)
 296            .add_request_handler(leave_room)
 297            .add_request_handler(set_room_participant_role)
 298            .add_request_handler(call)
 299            .add_request_handler(cancel_call)
 300            .add_message_handler(decline_call)
 301            .add_request_handler(update_participant_location)
 302            .add_request_handler(share_project)
 303            .add_message_handler(unshare_project)
 304            .add_request_handler(join_project)
 305            .add_message_handler(leave_project)
 306            .add_request_handler(update_project)
 307            .add_request_handler(update_worktree)
 308            .add_request_handler(update_repository)
 309            .add_request_handler(remove_repository)
 310            .add_message_handler(start_language_server)
 311            .add_message_handler(update_language_server)
 312            .add_message_handler(update_diagnostic_summary)
 313            .add_message_handler(update_worktree_settings)
 314            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 315            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 316            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 317            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 318            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 319            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 320            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 321            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 322            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 323            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 324            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 325            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 326            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 327            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 328            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 329            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 330            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 331            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 332            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 333            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 334            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 335            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 336            .add_request_handler(
 337                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 338            )
 339            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 340            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 341            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 342            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 343            .add_request_handler(
 344                forward_read_only_project_request::<proto::LanguageServerIdForName>,
 345            )
 346            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
 347            .add_request_handler(
 348                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 349            )
 350            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 351            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 352            .add_request_handler(
 353                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 354            )
 355            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 356            .add_request_handler(
 357                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 358            )
 359            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 360            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 361            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 362            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 363            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 364            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 365            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 366            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 367            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 368            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 369            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 370            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 371            .add_request_handler(
 372                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 373            )
 374            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 375            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 376            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 377            .add_request_handler(multi_lsp_query)
 378            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 379            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 380            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 381            .add_message_handler(create_buffer_for_peer)
 382            .add_request_handler(update_buffer)
 383            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 384            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 385            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 386            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 387            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 388            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 389            .add_message_handler(
 390                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 391            )
 392            .add_request_handler(get_users)
 393            .add_request_handler(fuzzy_search_users)
 394            .add_request_handler(request_contact)
 395            .add_request_handler(remove_contact)
 396            .add_request_handler(respond_to_contact_request)
 397            .add_message_handler(subscribe_to_channels)
 398            .add_request_handler(create_channel)
 399            .add_request_handler(delete_channel)
 400            .add_request_handler(invite_channel_member)
 401            .add_request_handler(remove_channel_member)
 402            .add_request_handler(set_channel_member_role)
 403            .add_request_handler(set_channel_visibility)
 404            .add_request_handler(rename_channel)
 405            .add_request_handler(join_channel_buffer)
 406            .add_request_handler(leave_channel_buffer)
 407            .add_message_handler(update_channel_buffer)
 408            .add_request_handler(rejoin_channel_buffers)
 409            .add_request_handler(get_channel_members)
 410            .add_request_handler(respond_to_channel_invite)
 411            .add_request_handler(join_channel)
 412            .add_request_handler(join_channel_chat)
 413            .add_message_handler(leave_channel_chat)
 414            .add_request_handler(send_channel_message)
 415            .add_request_handler(remove_channel_message)
 416            .add_request_handler(update_channel_message)
 417            .add_request_handler(get_channel_messages)
 418            .add_request_handler(get_channel_messages_by_id)
 419            .add_request_handler(get_notifications)
 420            .add_request_handler(mark_notification_as_read)
 421            .add_request_handler(move_channel)
 422            .add_request_handler(reorder_channel)
 423            .add_request_handler(follow)
 424            .add_message_handler(unfollow)
 425            .add_message_handler(update_followers)
 426            .add_request_handler(get_private_user_info)
 427            .add_request_handler(get_llm_api_token)
 428            .add_request_handler(accept_terms_of_service)
 429            .add_message_handler(acknowledge_channel_message)
 430            .add_message_handler(acknowledge_buffer_version)
 431            .add_request_handler(get_supermaven_api_key)
 432            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 433            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 434            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 435            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 436            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 437            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 438            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 439            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 440            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 441            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 442            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 443            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 444            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 445            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 446            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 447            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 448            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 449            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 450            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 451            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 452            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 453            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 454            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 455            .add_message_handler(update_context);
 456
 457        Arc::new(server)
 458    }
 459
    /// Schedules background cleanup of state left behind by stale servers.
    ///
    /// Spawns a detached task that waits for `CLEANUP_TIMEOUT` and then, using
    /// `app_state.config.zed_environment` and this server's id:
    /// 1. deletes stale channel-chat participants;
    /// 2. clears stale channel-buffer collaborators and sends the refreshed
    ///    collaborator list to the connections the database returns;
    /// 3. clears stale room participants, broadcasting room/channel updates,
    ///    `CallCanceled` messages and refreshed contact entries, and deletes the
    ///    LiveKit room when no participants remain;
    /// 4. repeats the chat-participant deletion, clears old worktree entries,
    ///    and deletes stale servers.
    ///
    /// Every database result is passed through `trace_err` and otherwise
    /// discarded, so a failing step does not prevent later steps from running.
    /// The method only spawns the task and always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        // Read the id once so the cleanup task works with a fixed value.
        let server_id = *self.id.lock();
        // Clone everything the detached task needs, so the task owns its state.
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators, then send the
                    // refreshed collaborator list to each returned connection.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These stay empty/false if clearing the room fails, which
                        // turns the notification and LiveKit steps below into no-ops.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and users whose calls were
                            // canceled need their contact entries refreshed.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // The pool guard is confined to this block, so it is released
                        // before the `.await`s in the contact loop below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Recompute each affected user's busy state and push the
                        // updated contact entry to every connection of each of
                        // their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only tear down the LiveKit room once nobody is left in it.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // NOTE(review): this repeats the `delete_stale_channel_chat_participants`
                // call made at the start of the task; confirm the second call is intended.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                // Remove the stale server records last, after their resources.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 630
 631    pub fn teardown(&self) {
 632        self.peer.teardown();
 633        self.connection_pool.lock().reset();
 634        let _ = self.teardown.send(true);
 635    }
 636
 637    #[cfg(test)]
 638    pub fn reset(&self, id: ServerId) {
 639        self.teardown();
 640        *self.id.lock() = id;
 641        self.peer.reset(id.0 as u32);
 642        let _ = self.teardown.send(false);
 643    }
 644
 645    #[cfg(test)]
 646    pub fn id(&self) -> ServerId {
 647        *self.id.lock()
 648    }
 649
 650    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 651    where
 652        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 653        Fut: 'static + Send + Future<Output = Result<()>>,
 654        M: EnvelopedMessage,
 655    {
 656        let prev_handler = self.handlers.insert(
 657            TypeId::of::<M>(),
 658            Box::new(move |envelope, session| {
 659                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 660                let received_at = envelope.received_at;
 661                tracing::info!("message received");
 662                let start_time = Instant::now();
 663                let future = (handler)(*envelope, session);
 664                async move {
 665                    let result = future.await;
 666                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 667                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 668                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 669
 670                    match result {
 671                        Err(error) => {
 672                            tracing::error!(
 673                                ?error,
 674                                total_duration_ms,
 675                                processing_duration_ms,
 676                                queue_duration_ms,
 677                                "error handling message"
 678                            )
 679                        }
 680                        Ok(()) => tracing::info!(
 681                            total_duration_ms,
 682                            processing_duration_ms,
 683                            queue_duration_ms,
 684                            "finished handling message"
 685                        ),
 686                    }
 687                }
 688                .boxed()
 689            }),
 690        );
 691        if prev_handler.is_some() {
 692            panic!("registered a handler for the same message twice");
 693        }
 694        self
 695    }
 696
 697    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 698    where
 699        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 700        Fut: 'static + Send + Future<Output = Result<()>>,
 701        M: EnvelopedMessage,
 702    {
 703        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 704        self
 705    }
 706
 707    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 708    where
 709        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 710        Fut: Send + Future<Output = Result<()>>,
 711        M: RequestMessage,
 712    {
 713        let handler = Arc::new(handler);
 714        self.add_handler(move |envelope, session| {
 715            let receipt = envelope.receipt();
 716            let handler = handler.clone();
 717            async move {
 718                let peer = session.peer.clone();
 719                let responded = Arc::new(AtomicBool::default());
 720                let response = Response {
 721                    peer: peer.clone(),
 722                    responded: responded.clone(),
 723                    receipt,
 724                };
 725                match (handler)(envelope.payload, response, session).await {
 726                    Ok(()) => {
 727                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 728                            Ok(())
 729                        } else {
 730                            Err(anyhow!("handler did not send a response"))?
 731                        }
 732                    }
 733                    Err(error) => {
 734                        let proto_err = match &error {
 735                            Error::Internal(err) => err.to_proto(),
 736                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 737                        };
 738                        peer.respond_with_error(receipt, proto_err)?;
 739                        Err(error)
 740                    }
 741                }
 742            }
 743        })
 744    }
 745
    /// Builds the future that drives a single client connection from handshake to sign-out.
    ///
    /// The returned future:
    /// - exits immediately if the server is already tearing down;
    /// - registers `connection` with the peer and builds the per-connection [`Session`];
    /// - runs `send_initial_client_update` and releases `connection_guard` once it succeeds;
    /// - dispatches incoming messages to the registered handlers until the connection closes,
    ///   the I/O future finishes, or teardown is signalled;
    /// - unless interrupted by teardown, finishes by calling `connection_lost`.
    ///
    /// `address`, `user_agent`, and `geoip_country_code` are recorded on the tracing span.
    /// `send_connection_id`, when provided, is forwarded to `send_initial_client_update`,
    /// which sends it the assigned `ConnectionId`.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse to serve the connection if teardown has already been signalled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // NOTE: this shadows the `user_agent` parameter (already recorded on the span);
            // it is the server's own user agent for outgoing HTTP requests.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            // A new HTTP client is created per connection; within this function it is only
            // used to build the optional Supermaven admin client below.
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The initial client update succeeded; release the connection guard.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            // Number of handlers currently holding a permit (i.e. in flight).
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // A permit is acquired before pulling the next message, which caps in-flight
                // handlers at MAX_CONCURRENT_HANDLERS. `acquire_owned` only fails on a closed
                // semaphore; this function never closes it.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` checks arms in order: teardown, I/O completion, progress on
                // foreground handlers, then the next incoming message.
                futures::select_biased! {
                    // Teardown: return without running `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                multi_lsp_query_request=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is released only once the handler finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached; foreground ones are polled
                                // by this loop via `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that have not completed yet.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
 927
 928    async fn send_initial_client_update(
 929        &self,
 930        connection_id: ConnectionId,
 931        zed_version: ZedVersion,
 932        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 933        session: &Session,
 934    ) -> Result<()> {
 935        self.peer.send(
 936            connection_id,
 937            proto::Hello {
 938                peer_id: Some(connection_id.into()),
 939            },
 940        )?;
 941        tracing::info!("sent hello message");
 942        if let Some(send_connection_id) = send_connection_id.take() {
 943            let _ = send_connection_id.send(connection_id);
 944        }
 945
 946        match &session.principal {
 947            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 948                if !user.connected_once {
 949                    self.peer.send(connection_id, proto::ShowContacts {})?;
 950                    self.app_state
 951                        .db
 952                        .set_user_connected_once(user.id, true)
 953                        .await?;
 954                }
 955
 956                update_user_plan(session).await?;
 957
 958                let contacts = self.app_state.db.get_contacts(user.id).await?;
 959
 960                {
 961                    let mut pool = self.connection_pool.lock();
 962                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 963                    self.peer.send(
 964                        connection_id,
 965                        build_initial_contacts_update(contacts, &pool),
 966                    )?;
 967                }
 968
 969                if should_auto_subscribe_to_channels(zed_version) {
 970                    subscribe_user_to_channels(user.id, session).await?;
 971                }
 972
 973                if let Some(incoming_call) =
 974                    self.app_state.db.incoming_call_for_user(user.id).await?
 975                {
 976                    self.peer.send(connection_id, incoming_call)?;
 977                }
 978
 979                update_user_contacts(user.id, session).await?;
 980            }
 981        }
 982
 983        Ok(())
 984    }
 985
 986    pub async fn invite_code_redeemed(
 987        self: &Arc<Self>,
 988        inviter_id: UserId,
 989        invitee_id: UserId,
 990    ) -> Result<()> {
 991        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 992            if let Some(code) = &user.invite_code {
 993                let pool = self.connection_pool.lock();
 994                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 995                for connection_id in pool.user_connection_ids(inviter_id) {
 996                    self.peer.send(
 997                        connection_id,
 998                        proto::UpdateContacts {
 999                            contacts: vec![invitee_contact.clone()],
1000                            ..Default::default()
1001                        },
1002                    )?;
1003                    self.peer.send(
1004                        connection_id,
1005                        proto::UpdateInviteInfo {
1006                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1007                            count: user.invite_count as u32,
1008                        },
1009                    )?;
1010                }
1011            }
1012        }
1013        Ok(())
1014    }
1015
1016    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1017        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1018            if let Some(invite_code) = &user.invite_code {
1019                let pool = self.connection_pool.lock();
1020                for connection_id in pool.user_connection_ids(user_id) {
1021                    self.peer.send(
1022                        connection_id,
1023                        proto::UpdateInviteInfo {
1024                            url: format!(
1025                                "{}{}",
1026                                self.app_state.config.invite_link_prefix, invite_code
1027                            ),
1028                            count: user.invite_count as u32,
1029                        },
1030                    )?;
1031                }
1032            }
1033        }
1034        Ok(())
1035    }
1036
1037    pub async fn update_plan_for_user(
1038        self: &Arc<Self>,
1039        user_id: UserId,
1040        update_user_plan: proto::UpdateUserPlan,
1041    ) -> Result<()> {
1042        let pool = self.connection_pool.lock();
1043        for connection_id in pool.user_connection_ids(user_id) {
1044            self.peer
1045                .send(connection_id, update_user_plan.clone())
1046                .trace_err();
1047        }
1048
1049        Ok(())
1050    }
1051
1052    /// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan`
1053    /// message on the Collab server.
1054    ///
1055    /// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint.
1056    pub async fn update_plan_for_user_legacy(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1057        let user = self
1058            .app_state
1059            .db
1060            .get_user_by_id(user_id)
1061            .await?
1062            .context("user not found")?;
1063
1064        let update_user_plan = make_update_user_plan_message(
1065            &user,
1066            user.admin,
1067            &self.app_state.db,
1068            self.app_state.llm_db.clone(),
1069        )
1070        .await?;
1071
1072        self.update_plan_for_user(user_id, update_user_plan).await
1073    }
1074
1075    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
1076        let pool = self.connection_pool.lock();
1077        for connection_id in pool.user_connection_ids(user_id) {
1078            self.peer
1079                .send(connection_id, proto::RefreshLlmToken {})
1080                .trace_err();
1081        }
1082    }
1083
1084    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1085        ServerSnapshot {
1086            connection_pool: ConnectionPoolGuard {
1087                guard: self.connection_pool.lock(),
1088                _not_send: PhantomData,
1089            },
1090            peer: &self.peer,
1091        }
1092    }
1093}
1094
1095impl Deref for ConnectionPoolGuard<'_> {
1096    type Target = ConnectionPool;
1097
1098    fn deref(&self) -> &Self::Target {
1099        &self.guard
1100    }
1101}
1102
1103impl DerefMut for ConnectionPoolGuard<'_> {
1104    fn deref_mut(&mut self) -> &mut Self::Target {
1105        &mut self.guard
1106    }
1107}
1108
1109impl Drop for ConnectionPoolGuard<'_> {
1110    fn drop(&mut self) {
1111        #[cfg(test)]
1112        self.check_invariants();
1113    }
1114}
1115
1116fn broadcast<F>(
1117    sender_id: Option<ConnectionId>,
1118    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1119    mut f: F,
1120) where
1121    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1122{
1123    for receiver_id in receiver_ids {
1124        if Some(receiver_id) != sender_id {
1125            if let Err(error) = f(receiver_id) {
1126                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1127            }
1128        }
1129    }
1130}
1131
1132pub struct ProtocolVersion(u32);
1133
1134impl Header for ProtocolVersion {
1135    fn name() -> &'static HeaderName {
1136        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1137        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1138    }
1139
1140    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1141    where
1142        Self: Sized,
1143        I: Iterator<Item = &'i axum::http::HeaderValue>,
1144    {
1145        let version = values
1146            .next()
1147            .ok_or_else(axum::headers::Error::invalid)?
1148            .to_str()
1149            .map_err(|_| axum::headers::Error::invalid())?
1150            .parse()
1151            .map_err(|_| axum::headers::Error::invalid())?;
1152        Ok(Self(version))
1153    }
1154
1155    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1156        values.extend([self.0.to_string().parse().unwrap()]);
1157    }
1158}
1159
1160pub struct AppVersionHeader(SemanticVersion);
1161impl Header for AppVersionHeader {
1162    fn name() -> &'static HeaderName {
1163        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1164        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1165    }
1166
1167    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1168    where
1169        Self: Sized,
1170        I: Iterator<Item = &'i axum::http::HeaderValue>,
1171    {
1172        let version = values
1173            .next()
1174            .ok_or_else(axum::headers::Error::invalid)?
1175            .to_str()
1176            .map_err(|_| axum::headers::Error::invalid())?
1177            .parse()
1178            .map_err(|_| axum::headers::Error::invalid())?;
1179        Ok(Self(version))
1180    }
1181
1182    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1183        values.extend([self.0.to_string().parse().unwrap()]);
1184    }
1185}
1186
1187pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1188    Router::new()
1189        .route("/rpc", get(handle_websocket_request))
1190        .layer(
1191            ServiceBuilder::new()
1192                .layer(Extension(server.app_state.clone()))
1193                .layer(middleware::from_fn(auth::validate_header)),
1194        )
1195        .route("/metrics", get(handle_metrics))
1196        .layer(Extension(server))
1197}
1198
1199pub async fn handle_websocket_request(
1200    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1201    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1202    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1203    Extension(server): Extension<Arc<Server>>,
1204    Extension(principal): Extension<Principal>,
1205    user_agent: Option<TypedHeader<UserAgent>>,
1206    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1207    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1208    ws: WebSocketUpgrade,
1209) -> axum::response::Response {
1210    if protocol_version != rpc::PROTOCOL_VERSION {
1211        return (
1212            StatusCode::UPGRADE_REQUIRED,
1213            "client must be upgraded".to_string(),
1214        )
1215            .into_response();
1216    }
1217
1218    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1219        return (
1220            StatusCode::UPGRADE_REQUIRED,
1221            "no version header found".to_string(),
1222        )
1223            .into_response();
1224    };
1225
1226    if !version.can_collaborate() {
1227        return (
1228            StatusCode::UPGRADE_REQUIRED,
1229            "client must be upgraded".to_string(),
1230        )
1231            .into_response();
1232    }
1233
1234    let socket_address = socket_address.to_string();
1235
1236    // Acquire connection guard before WebSocket upgrade
1237    let connection_guard = match ConnectionGuard::try_acquire() {
1238        Ok(guard) => guard,
1239        Err(()) => {
1240            return (
1241                StatusCode::SERVICE_UNAVAILABLE,
1242                "Too many concurrent connections",
1243            )
1244                .into_response();
1245        }
1246    };
1247
1248    ws.on_upgrade(move |socket| {
1249        let socket = socket
1250            .map_ok(to_tungstenite_message)
1251            .err_into()
1252            .with(|message| async move { to_axum_message(message) });
1253        let connection = Connection::new(Box::pin(socket));
1254        async move {
1255            server
1256                .handle_connection(
1257                    connection,
1258                    socket_address,
1259                    principal,
1260                    version,
1261                    user_agent.map(|header| header.to_string()),
1262                    country_code_header.map(|header| header.to_string()),
1263                    system_id_header.map(|header| header.to_string()),
1264                    None,
1265                    Executor::Production,
1266                    Some(connection_guard),
1267                )
1268                .await;
1269        }
1270    })
1271}
1272
1273pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1274    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1275    let connections_metric = CONNECTIONS_METRIC
1276        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1277
1278    let connections = server
1279        .connection_pool
1280        .lock()
1281        .connections()
1282        .filter(|connection| !connection.admin)
1283        .count();
1284    connections_metric.set(connections as _);
1285
1286    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1287    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1288        register_int_gauge!(
1289            "shared_projects",
1290            "number of open projects with one or more guests"
1291        )
1292        .unwrap()
1293    });
1294
1295    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1296    shared_projects_metric.set(shared_projects as _);
1297
1298    let encoder = prometheus::TextEncoder::new();
1299    let metric_families = prometheus::gather();
1300    let encoded_metrics = encoder
1301        .encode_to_string(&metric_families)
1302        .map_err(|err| anyhow!("{err}"))?;
1303    Ok(encoded_metrics)
1304}
1305
1306#[instrument(err, skip(executor))]
1307async fn connection_lost(
1308    session: Session,
1309    mut teardown: watch::Receiver<bool>,
1310    executor: Executor,
1311) -> Result<()> {
1312    session.peer.disconnect(session.connection_id);
1313    session
1314        .connection_pool()
1315        .await
1316        .remove_connection(session.connection_id)?;
1317
1318    session
1319        .db()
1320        .await
1321        .connection_lost(session.connection_id)
1322        .await
1323        .trace_err();
1324
1325    futures::select_biased! {
1326        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1327
1328            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1329            leave_room_for_session(&session, session.connection_id).await.trace_err();
1330            leave_channel_buffers_for_session(&session)
1331                .await
1332                .trace_err();
1333
1334            if !session
1335                .connection_pool()
1336                .await
1337                .is_user_online(session.user_id())
1338            {
1339                let db = session.db().await;
1340                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1341                    room_updated(&room, &session.peer);
1342                }
1343            }
1344
1345            update_user_contacts(session.user_id(), &session).await?;
1346        },
1347        _ = teardown.changed().fuse() => {}
1348    }
1349
1350    Ok(())
1351}
1352
1353/// Acknowledges a ping from a client, used to keep the connection alive.
1354async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1355    response.send(proto::Ack {})?;
1356    Ok(())
1357}
1358
1359/// Creates a new room for calling (outside of channels)
1360async fn create_room(
1361    _request: proto::CreateRoom,
1362    response: Response<proto::CreateRoom>,
1363    session: Session,
1364) -> Result<()> {
1365    let livekit_room = nanoid::nanoid!(30);
1366
1367    let live_kit_connection_info = util::maybe!(async {
1368        let live_kit = session.app_state.livekit_client.as_ref();
1369        let live_kit = live_kit?;
1370        let user_id = session.user_id().to_string();
1371
1372        let token = live_kit
1373            .room_token(&livekit_room, &user_id.to_string())
1374            .trace_err()?;
1375
1376        Some(proto::LiveKitConnectionInfo {
1377            server_url: live_kit.url().into(),
1378            token,
1379            can_publish: true,
1380        })
1381    })
1382    .await;
1383
1384    let room = session
1385        .db()
1386        .await
1387        .create_room(session.user_id(), session.connection_id, &livekit_room)
1388        .await?;
1389
1390    response.send(proto::CreateRoomResponse {
1391        room: Some(room.clone()),
1392        live_kit_connection_info,
1393    })?;
1394
1395    update_user_contacts(session.user_id(), &session).await?;
1396    Ok(())
1397}
1398
1399/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1400async fn join_room(
1401    request: proto::JoinRoom,
1402    response: Response<proto::JoinRoom>,
1403    session: Session,
1404) -> Result<()> {
1405    let room_id = RoomId::from_proto(request.id);
1406
1407    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1408
1409    if let Some(channel_id) = channel_id {
1410        return join_channel_internal(channel_id, Box::new(response), session).await;
1411    }
1412
1413    let joined_room = {
1414        let room = session
1415            .db()
1416            .await
1417            .join_room(room_id, session.user_id(), session.connection_id)
1418            .await?;
1419        room_updated(&room.room, &session.peer);
1420        room.into_inner()
1421    };
1422
1423    for connection_id in session
1424        .connection_pool()
1425        .await
1426        .user_connection_ids(session.user_id())
1427    {
1428        session
1429            .peer
1430            .send(
1431                connection_id,
1432                proto::CallCanceled {
1433                    room_id: room_id.to_proto(),
1434                },
1435            )
1436            .trace_err();
1437    }
1438
1439    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1440    {
1441        live_kit
1442            .room_token(
1443                &joined_room.room.livekit_room,
1444                &session.user_id().to_string(),
1445            )
1446            .trace_err()
1447            .map(|token| proto::LiveKitConnectionInfo {
1448                server_url: live_kit.url().into(),
1449                token,
1450                can_publish: true,
1451            })
1452    } else {
1453        None
1454    };
1455
1456    response.send(proto::JoinRoomResponse {
1457        room: Some(joined_room.room),
1458        channel_id: None,
1459        live_kit_connection_info,
1460    })?;
1461
1462    update_user_contacts(session.user_id(), &session).await?;
1463    Ok(())
1464}
1465
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database call (`rejoin_room`) is made on behalf of this session's user
/// and current connection. The requester receives the resulting room together
/// with its reshared and rejoined projects. For each reshared project, its
/// collaborators are told that `old_connection_id` has been replaced by this
/// connection and are forwarded the project's worktree metadata; rejoined
/// projects are handled by `notify_rejoined_projects`. After the inner scope
/// ends (dropping the value returned by the database), the room's channel (if
/// any) and the user's contacts are updated.
///
/// # Errors
/// Returns an error if the database call, the response, `notify_rejoined_projects`,
/// or `update_user_contacts` fails. Failures sending to other collaborators are
/// only logged via `trace_err`.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Assigned inside the scope below and used after it has ended.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Reply to the rejoining client first with the room and its projects.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Reshared projects: point each collaborator at this connection's new
        // peer id, then forward the project's worktree metadata to them.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Rejoined projects: notify collaborators and stream project state
        // back to this connection.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Move the room data out of the value returned by the database.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1557
1558fn notify_rejoined_projects(
1559    rejoined_projects: &mut Vec<RejoinedProject>,
1560    session: &Session,
1561) -> Result<()> {
1562    for project in rejoined_projects.iter() {
1563        for collaborator in &project.collaborators {
1564            session
1565                .peer
1566                .send(
1567                    collaborator.connection_id,
1568                    proto::UpdateProjectCollaborator {
1569                        project_id: project.id.to_proto(),
1570                        old_peer_id: Some(project.old_connection_id.into()),
1571                        new_peer_id: Some(session.connection_id.into()),
1572                    },
1573                )
1574                .trace_err();
1575        }
1576    }
1577
1578    for project in rejoined_projects {
1579        for worktree in mem::take(&mut project.worktrees) {
1580            // Stream this worktree's entries.
1581            let message = proto::UpdateWorktree {
1582                project_id: project.id.to_proto(),
1583                worktree_id: worktree.id,
1584                abs_path: worktree.abs_path.clone(),
1585                root_name: worktree.root_name,
1586                updated_entries: worktree.updated_entries,
1587                removed_entries: worktree.removed_entries,
1588                scan_id: worktree.scan_id,
1589                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1590                updated_repositories: worktree.updated_repositories,
1591                removed_repositories: worktree.removed_repositories,
1592            };
1593            for update in proto::split_worktree_update(message) {
1594                session.peer.send(session.connection_id, update)?;
1595            }
1596
1597            // Stream this worktree's diagnostics.
1598            for summary in worktree.diagnostic_summaries {
1599                session.peer.send(
1600                    session.connection_id,
1601                    proto::UpdateDiagnosticSummary {
1602                        project_id: project.id.to_proto(),
1603                        worktree_id: worktree.id,
1604                        summary: Some(summary),
1605                    },
1606                )?;
1607            }
1608
1609            for settings_file in worktree.settings_files {
1610                session.peer.send(
1611                    session.connection_id,
1612                    proto::UpdateWorktreeSettings {
1613                        project_id: project.id.to_proto(),
1614                        worktree_id: worktree.id,
1615                        path: settings_file.path,
1616                        content: Some(settings_file.content),
1617                        kind: Some(settings_file.kind.to_proto().into()),
1618                    },
1619                )?;
1620            }
1621        }
1622
1623        for repository in mem::take(&mut project.updated_repositories) {
1624            for update in split_repository_update(repository) {
1625                session.peer.send(session.connection_id, update)?;
1626            }
1627        }
1628
1629        for id in mem::take(&mut project.removed_repositories) {
1630            session.peer.send(
1631                session.connection_id,
1632                proto::RemoveRepository {
1633                    project_id: project.id.to_proto(),
1634                    id,
1635                },
1636            )?;
1637        }
1638    }
1639
1640    Ok(())
1641}
1642
1643/// leave room disconnects from the room.
1644async fn leave_room(
1645    _: proto::LeaveRoom,
1646    response: Response<proto::LeaveRoom>,
1647    session: Session,
1648) -> Result<()> {
1649    leave_room_for_session(&session, session.connection_id).await?;
1650    response.send(proto::Ack {})?;
1651    Ok(())
1652}
1653
1654/// Updates the permissions of someone else in the room.
1655async fn set_room_participant_role(
1656    request: proto::SetRoomParticipantRole,
1657    response: Response<proto::SetRoomParticipantRole>,
1658    session: Session,
1659) -> Result<()> {
1660    let user_id = UserId::from_proto(request.user_id);
1661    let role = ChannelRole::from(request.role());
1662
1663    let (livekit_room, can_publish) = {
1664        let room = session
1665            .db()
1666            .await
1667            .set_room_participant_role(
1668                session.user_id(),
1669                RoomId::from_proto(request.room_id),
1670                user_id,
1671                role,
1672            )
1673            .await?;
1674
1675        let livekit_room = room.livekit_room.clone();
1676        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1677        room_updated(&room, &session.peer);
1678        (livekit_room, can_publish)
1679    };
1680
1681    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1682        live_kit
1683            .update_participant(
1684                livekit_room.clone(),
1685                request.user_id.to_string(),
1686                livekit_api::proto::ParticipantPermission {
1687                    can_subscribe: true,
1688                    can_publish,
1689                    can_publish_data: can_publish,
1690                    hidden: false,
1691                    recorder: false,
1692                },
1693            )
1694            .await
1695            .trace_err();
1696    }
1697
1698    response.send(proto::Ack {})?;
1699    Ok(())
1700}
1701
1702/// Call someone else into the current room
1703async fn call(
1704    request: proto::Call,
1705    response: Response<proto::Call>,
1706    session: Session,
1707) -> Result<()> {
1708    let room_id = RoomId::from_proto(request.room_id);
1709    let calling_user_id = session.user_id();
1710    let calling_connection_id = session.connection_id;
1711    let called_user_id = UserId::from_proto(request.called_user_id);
1712    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1713    if !session
1714        .db()
1715        .await
1716        .has_contact(calling_user_id, called_user_id)
1717        .await?
1718    {
1719        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1720    }
1721
1722    let incoming_call = {
1723        let (room, incoming_call) = &mut *session
1724            .db()
1725            .await
1726            .call(
1727                room_id,
1728                calling_user_id,
1729                calling_connection_id,
1730                called_user_id,
1731                initial_project_id,
1732            )
1733            .await?;
1734        room_updated(room, &session.peer);
1735        mem::take(incoming_call)
1736    };
1737    update_user_contacts(called_user_id, &session).await?;
1738
1739    let mut calls = session
1740        .connection_pool()
1741        .await
1742        .user_connection_ids(called_user_id)
1743        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1744        .collect::<FuturesUnordered<_>>();
1745
1746    while let Some(call_response) = calls.next().await {
1747        match call_response.as_ref() {
1748            Ok(_) => {
1749                response.send(proto::Ack {})?;
1750                return Ok(());
1751            }
1752            Err(_) => {
1753                call_response.trace_err();
1754            }
1755        }
1756    }
1757
1758    {
1759        let room = session
1760            .db()
1761            .await
1762            .call_failed(room_id, called_user_id)
1763            .await?;
1764        room_updated(&room, &session.peer);
1765    }
1766    update_user_contacts(called_user_id, &session).await?;
1767
1768    Err(anyhow!("failed to ring user"))?
1769}
1770
1771/// Cancel an outgoing call.
1772async fn cancel_call(
1773    request: proto::CancelCall,
1774    response: Response<proto::CancelCall>,
1775    session: Session,
1776) -> Result<()> {
1777    let called_user_id = UserId::from_proto(request.called_user_id);
1778    let room_id = RoomId::from_proto(request.room_id);
1779    {
1780        let room = session
1781            .db()
1782            .await
1783            .cancel_call(room_id, session.connection_id, called_user_id)
1784            .await?;
1785        room_updated(&room, &session.peer);
1786    }
1787
1788    for connection_id in session
1789        .connection_pool()
1790        .await
1791        .user_connection_ids(called_user_id)
1792    {
1793        session
1794            .peer
1795            .send(
1796                connection_id,
1797                proto::CallCanceled {
1798                    room_id: room_id.to_proto(),
1799                },
1800            )
1801            .trace_err();
1802    }
1803    response.send(proto::Ack {})?;
1804
1805    update_user_contacts(called_user_id, &session).await?;
1806    Ok(())
1807}
1808
1809/// Decline an incoming call.
1810async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1811    let room_id = RoomId::from_proto(message.room_id);
1812    {
1813        let room = session
1814            .db()
1815            .await
1816            .decline_call(Some(room_id), session.user_id())
1817            .await?
1818            .context("declining call")?;
1819        room_updated(&room, &session.peer);
1820    }
1821
1822    for connection_id in session
1823        .connection_pool()
1824        .await
1825        .user_connection_ids(session.user_id())
1826    {
1827        session
1828            .peer
1829            .send(
1830                connection_id,
1831                proto::CallCanceled {
1832                    room_id: room_id.to_proto(),
1833                },
1834            )
1835            .trace_err();
1836    }
1837    update_user_contacts(session.user_id(), &session).await?;
1838    Ok(())
1839}
1840
1841/// Updates other participants in the room with your current location.
1842async fn update_participant_location(
1843    request: proto::UpdateParticipantLocation,
1844    response: Response<proto::UpdateParticipantLocation>,
1845    session: Session,
1846) -> Result<()> {
1847    let room_id = RoomId::from_proto(request.room_id);
1848    let location = request.location.context("invalid location")?;
1849
1850    let db = session.db().await;
1851    let room = db
1852        .update_room_participant_location(room_id, session.connection_id, location)
1853        .await?;
1854
1855    room_updated(&room, &session.peer);
1856    response.send(proto::Ack {})?;
1857    Ok(())
1858}
1859
1860/// Share a project into the room.
1861async fn share_project(
1862    request: proto::ShareProject,
1863    response: Response<proto::ShareProject>,
1864    session: Session,
1865) -> Result<()> {
1866    let (project_id, room) = &*session
1867        .db()
1868        .await
1869        .share_project(
1870            RoomId::from_proto(request.room_id),
1871            session.connection_id,
1872            &request.worktrees,
1873            request.is_ssh_project,
1874        )
1875        .await?;
1876    response.send(proto::ShareProjectResponse {
1877        project_id: project_id.to_proto(),
1878    })?;
1879    room_updated(room, &session.peer);
1880
1881    Ok(())
1882}
1883
1884/// Unshare a project from the room.
1885async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1886    let project_id = ProjectId::from_proto(message.project_id);
1887    unshare_project_internal(project_id, session.connection_id, &session).await
1888}
1889
1890async fn unshare_project_internal(
1891    project_id: ProjectId,
1892    connection_id: ConnectionId,
1893    session: &Session,
1894) -> Result<()> {
1895    let delete = {
1896        let room_guard = session
1897            .db()
1898            .await
1899            .unshare_project(project_id, connection_id)
1900            .await?;
1901
1902        let (delete, room, guest_connection_ids) = &*room_guard;
1903
1904        let message = proto::UnshareProject {
1905            project_id: project_id.to_proto(),
1906        };
1907
1908        broadcast(
1909            Some(connection_id),
1910            guest_connection_ids.iter().copied(),
1911            |conn_id| session.peer.send(conn_id, message.clone()),
1912        );
1913        if let Some(room) = room {
1914            room_updated(room, &session.peer);
1915        }
1916
1917        *delete
1918    };
1919
1920    if delete {
1921        let db = session.db().await;
1922        db.delete_project(project_id).await?;
1923    }
1924
1925    Ok(())
1926}
1927
1928/// Join someone elses shared project.
1929async fn join_project(
1930    request: proto::JoinProject,
1931    response: Response<proto::JoinProject>,
1932    session: Session,
1933) -> Result<()> {
1934    let project_id = ProjectId::from_proto(request.project_id);
1935
1936    tracing::info!(%project_id, "join project");
1937
1938    let db = session.db().await;
1939    let (project, replica_id) = &mut *db
1940        .join_project(
1941            project_id,
1942            session.connection_id,
1943            session.user_id(),
1944            request.committer_name.clone(),
1945            request.committer_email.clone(),
1946        )
1947        .await?;
1948    drop(db);
1949    tracing::info!(%project_id, "join remote project");
1950    let collaborators = project
1951        .collaborators
1952        .iter()
1953        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1954        .map(|collaborator| collaborator.to_proto())
1955        .collect::<Vec<_>>();
1956    let project_id = project.id;
1957    let guest_user_id = session.user_id();
1958
1959    let worktrees = project
1960        .worktrees
1961        .iter()
1962        .map(|(id, worktree)| proto::WorktreeMetadata {
1963            id: *id,
1964            root_name: worktree.root_name.clone(),
1965            visible: worktree.visible,
1966            abs_path: worktree.abs_path.clone(),
1967        })
1968        .collect::<Vec<_>>();
1969
1970    let add_project_collaborator = proto::AddProjectCollaborator {
1971        project_id: project_id.to_proto(),
1972        collaborator: Some(proto::Collaborator {
1973            peer_id: Some(session.connection_id.into()),
1974            replica_id: replica_id.0 as u32,
1975            user_id: guest_user_id.to_proto(),
1976            is_host: false,
1977            committer_name: request.committer_name.clone(),
1978            committer_email: request.committer_email.clone(),
1979        }),
1980    };
1981
1982    for collaborator in &collaborators {
1983        session
1984            .peer
1985            .send(
1986                collaborator.peer_id.unwrap().into(),
1987                add_project_collaborator.clone(),
1988            )
1989            .trace_err();
1990    }
1991
1992    // First, we send the metadata associated with each worktree.
1993    response.send(proto::JoinProjectResponse {
1994        project_id: project.id.0 as u64,
1995        worktrees: worktrees.clone(),
1996        replica_id: replica_id.0 as u32,
1997        collaborators: collaborators.clone(),
1998        language_servers: project.language_servers.clone(),
1999        role: project.role.into(),
2000    })?;
2001
2002    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2003        // Stream this worktree's entries.
2004        let message = proto::UpdateWorktree {
2005            project_id: project_id.to_proto(),
2006            worktree_id,
2007            abs_path: worktree.abs_path.clone(),
2008            root_name: worktree.root_name,
2009            updated_entries: worktree.entries,
2010            removed_entries: Default::default(),
2011            scan_id: worktree.scan_id,
2012            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2013            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2014            removed_repositories: Default::default(),
2015        };
2016        for update in proto::split_worktree_update(message) {
2017            session.peer.send(session.connection_id, update.clone())?;
2018        }
2019
2020        // Stream this worktree's diagnostics.
2021        for summary in worktree.diagnostic_summaries {
2022            session.peer.send(
2023                session.connection_id,
2024                proto::UpdateDiagnosticSummary {
2025                    project_id: project_id.to_proto(),
2026                    worktree_id: worktree.id,
2027                    summary: Some(summary),
2028                },
2029            )?;
2030        }
2031
2032        for settings_file in worktree.settings_files {
2033            session.peer.send(
2034                session.connection_id,
2035                proto::UpdateWorktreeSettings {
2036                    project_id: project_id.to_proto(),
2037                    worktree_id: worktree.id,
2038                    path: settings_file.path,
2039                    content: Some(settings_file.content),
2040                    kind: Some(settings_file.kind.to_proto() as i32),
2041                },
2042            )?;
2043        }
2044    }
2045
2046    for repository in mem::take(&mut project.repositories) {
2047        for update in split_repository_update(repository) {
2048            session.peer.send(session.connection_id, update)?;
2049        }
2050    }
2051
2052    for language_server in &project.language_servers {
2053        session.peer.send(
2054            session.connection_id,
2055            proto::UpdateLanguageServer {
2056                project_id: project_id.to_proto(),
2057                server_name: Some(language_server.name.clone()),
2058                language_server_id: language_server.id,
2059                variant: Some(
2060                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2061                        proto::LspDiskBasedDiagnosticsUpdated {},
2062                    ),
2063                ),
2064            },
2065        )?;
2066    }
2067
2068    Ok(())
2069}
2070
2071/// Leave someone elses shared project.
2072async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
2073    let sender_id = session.connection_id;
2074    let project_id = ProjectId::from_proto(request.project_id);
2075    let db = session.db().await;
2076
2077    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2078    tracing::info!(
2079        %project_id,
2080        "leave project"
2081    );
2082
2083    project_left(project, &session);
2084    if let Some(room) = room {
2085        room_updated(room, &session.peer);
2086    }
2087
2088    Ok(())
2089}
2090
2091/// Updates other participants with changes to the project
2092async fn update_project(
2093    request: proto::UpdateProject,
2094    response: Response<proto::UpdateProject>,
2095    session: Session,
2096) -> Result<()> {
2097    let project_id = ProjectId::from_proto(request.project_id);
2098    let (room, guest_connection_ids) = &*session
2099        .db()
2100        .await
2101        .update_project(project_id, session.connection_id, &request.worktrees)
2102        .await?;
2103    broadcast(
2104        Some(session.connection_id),
2105        guest_connection_ids.iter().copied(),
2106        |connection_id| {
2107            session
2108                .peer
2109                .forward_send(session.connection_id, connection_id, request.clone())
2110        },
2111    );
2112    if let Some(room) = room {
2113        room_updated(room, &session.peer);
2114    }
2115    response.send(proto::Ack {})?;
2116
2117    Ok(())
2118}
2119
2120/// Updates other participants with changes to the worktree
2121async fn update_worktree(
2122    request: proto::UpdateWorktree,
2123    response: Response<proto::UpdateWorktree>,
2124    session: Session,
2125) -> Result<()> {
2126    let guest_connection_ids = session
2127        .db()
2128        .await
2129        .update_worktree(&request, session.connection_id)
2130        .await?;
2131
2132    broadcast(
2133        Some(session.connection_id),
2134        guest_connection_ids.iter().copied(),
2135        |connection_id| {
2136            session
2137                .peer
2138                .forward_send(session.connection_id, connection_id, request.clone())
2139        },
2140    );
2141    response.send(proto::Ack {})?;
2142    Ok(())
2143}
2144
2145async fn update_repository(
2146    request: proto::UpdateRepository,
2147    response: Response<proto::UpdateRepository>,
2148    session: Session,
2149) -> Result<()> {
2150    let guest_connection_ids = session
2151        .db()
2152        .await
2153        .update_repository(&request, session.connection_id)
2154        .await?;
2155
2156    broadcast(
2157        Some(session.connection_id),
2158        guest_connection_ids.iter().copied(),
2159        |connection_id| {
2160            session
2161                .peer
2162                .forward_send(session.connection_id, connection_id, request.clone())
2163        },
2164    );
2165    response.send(proto::Ack {})?;
2166    Ok(())
2167}
2168
2169async fn remove_repository(
2170    request: proto::RemoveRepository,
2171    response: Response<proto::RemoveRepository>,
2172    session: Session,
2173) -> Result<()> {
2174    let guest_connection_ids = session
2175        .db()
2176        .await
2177        .remove_repository(&request, session.connection_id)
2178        .await?;
2179
2180    broadcast(
2181        Some(session.connection_id),
2182        guest_connection_ids.iter().copied(),
2183        |connection_id| {
2184            session
2185                .peer
2186                .forward_send(session.connection_id, connection_id, request.clone())
2187        },
2188    );
2189    response.send(proto::Ack {})?;
2190    Ok(())
2191}
2192
2193/// Updates other participants with changes to the diagnostics
2194async fn update_diagnostic_summary(
2195    message: proto::UpdateDiagnosticSummary,
2196    session: Session,
2197) -> Result<()> {
2198    let guest_connection_ids = session
2199        .db()
2200        .await
2201        .update_diagnostic_summary(&message, session.connection_id)
2202        .await?;
2203
2204    broadcast(
2205        Some(session.connection_id),
2206        guest_connection_ids.iter().copied(),
2207        |connection_id| {
2208            session
2209                .peer
2210                .forward_send(session.connection_id, connection_id, message.clone())
2211        },
2212    );
2213
2214    Ok(())
2215}
2216
2217/// Updates other participants with changes to the worktree settings
2218async fn update_worktree_settings(
2219    message: proto::UpdateWorktreeSettings,
2220    session: Session,
2221) -> Result<()> {
2222    let guest_connection_ids = session
2223        .db()
2224        .await
2225        .update_worktree_settings(&message, session.connection_id)
2226        .await?;
2227
2228    broadcast(
2229        Some(session.connection_id),
2230        guest_connection_ids.iter().copied(),
2231        |connection_id| {
2232            session
2233                .peer
2234                .forward_send(session.connection_id, connection_id, message.clone())
2235        },
2236    );
2237
2238    Ok(())
2239}
2240
2241/// Notify other participants that a language server has started.
2242async fn start_language_server(
2243    request: proto::StartLanguageServer,
2244    session: Session,
2245) -> Result<()> {
2246    let guest_connection_ids = session
2247        .db()
2248        .await
2249        .start_language_server(&request, session.connection_id)
2250        .await?;
2251
2252    broadcast(
2253        Some(session.connection_id),
2254        guest_connection_ids.iter().copied(),
2255        |connection_id| {
2256            session
2257                .peer
2258                .forward_send(session.connection_id, connection_id, request.clone())
2259        },
2260    );
2261    Ok(())
2262}
2263
2264/// Notify other participants that a language server has changed.
2265async fn update_language_server(
2266    request: proto::UpdateLanguageServer,
2267    session: Session,
2268) -> Result<()> {
2269    let project_id = ProjectId::from_proto(request.project_id);
2270    let project_connection_ids = session
2271        .db()
2272        .await
2273        .project_connection_ids(project_id, session.connection_id, true)
2274        .await?;
2275    broadcast(
2276        Some(session.connection_id),
2277        project_connection_ids.iter().copied(),
2278        |connection_id| {
2279            session
2280                .peer
2281                .forward_send(session.connection_id, connection_id, request.clone())
2282        },
2283    );
2284    Ok(())
2285}
2286
2287/// forward a project request to the host. These requests should be read only
2288/// as guests are allowed to send them.
2289async fn forward_read_only_project_request<T>(
2290    request: T,
2291    response: Response<T>,
2292    session: Session,
2293) -> Result<()>
2294where
2295    T: EntityMessage + RequestMessage,
2296{
2297    let project_id = ProjectId::from_proto(request.remote_entity_id());
2298    let host_connection_id = session
2299        .db()
2300        .await
2301        .host_for_read_only_project_request(project_id, session.connection_id)
2302        .await?;
2303    let payload = session
2304        .peer
2305        .forward_request(session.connection_id, host_connection_id, request)
2306        .await?;
2307    response.send(payload)?;
2308    Ok(())
2309}
2310
2311/// forward a project request to the host. These requests are disallowed
2312/// for guests.
2313async fn forward_mutating_project_request<T>(
2314    request: T,
2315    response: Response<T>,
2316    session: Session,
2317) -> Result<()>
2318where
2319    T: EntityMessage + RequestMessage,
2320{
2321    let project_id = ProjectId::from_proto(request.remote_entity_id());
2322
2323    let host_connection_id = session
2324        .db()
2325        .await
2326        .host_for_mutating_project_request(project_id, session.connection_id)
2327        .await?;
2328    let payload = session
2329        .peer
2330        .forward_request(session.connection_id, host_connection_id, request)
2331        .await?;
2332    response.send(payload)?;
2333    Ok(())
2334}
2335
2336async fn multi_lsp_query(
2337    request: MultiLspQuery,
2338    response: Response<MultiLspQuery>,
2339    session: Session,
2340) -> Result<()> {
2341    tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2342    tracing::info!("multi_lsp_query message received");
2343    forward_mutating_project_request(request, response, session).await
2344}
2345
2346/// Notify other participants that a new buffer has been created
2347async fn create_buffer_for_peer(
2348    request: proto::CreateBufferForPeer,
2349    session: Session,
2350) -> Result<()> {
2351    session
2352        .db()
2353        .await
2354        .check_user_is_project_host(
2355            ProjectId::from_proto(request.project_id),
2356            session.connection_id,
2357        )
2358        .await?;
2359    let peer_id = request.peer_id.context("invalid peer id")?;
2360    session
2361        .peer
2362        .forward_send(session.connection_id, peer_id.into(), request)?;
2363    Ok(())
2364}
2365
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
///
/// The capability passed to `connections_for_buffer_update` is `ReadOnly`
/// when every operation either has no variant or is an `UpdateSelections`,
/// and `ReadWrite` otherwise (presumably the lookup rejects senders lacking
/// that capability — confirm in `connections_for_buffer_update`). The update
/// is then broadcast to the guest connections (excluding the sender) and,
/// unless the sender is the host, forwarded to the host as a request. The
/// caller receives an `Ack` only after the host has responded.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // Start from the weakest capability; any operation that is not a pure
    // selection update escalates it to ReadWrite.
    let mut capability = Capability::ReadOnly;

    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // The block scope ends the database guard's lifetime before the host
    // round-trip below is awaited; only the host's connection id escapes.
    let host = {
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // The host receives the update as a request (not a one-way send), so the
    // Ack below is only sent once the host has processed it.
    if host != session.connection_id {
        session
            .peer
            .forward_request(session.connection_id, host, request.clone())
            .await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2415
/// Relay an `UpdateContext` message to the project's host and guests,
/// excluding the sender.
///
/// The capability passed to `connections_for_buffer_update` is derived from
/// the context operation:
/// - a buffer operation whose inner variant is missing or `UpdateSelections`
///   needs only `ReadOnly`;
/// - a buffer operation with no inner `operation` at all, or any other
///   buffer-operation variant, needs `ReadWrite`;
/// - any other context-operation variant needs `ReadWrite`;
/// - a context operation with no variant needs `ReadOnly`.
///
/// Unlike `update_buffer`, the host is reached with a one-way `forward_send`
/// as part of the same broadcast, and no acknowledgement is awaited.
///
/// # Errors
/// Returns an error if the message carries no `operation`, or if the
/// database lookup fails.
async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let operation = message.operation.as_ref().context("invalid operation")?;
    let capability = match operation.variant.as_ref() {
        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
            if let Some(buffer_op) = buffer_op.operation.as_ref() {
                match buffer_op.variant {
                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
                        Capability::ReadOnly
                    }
                    _ => Capability::ReadWrite,
                }
            } else {
                // A buffer operation without a payload is treated
                // conservatively as requiring write access.
                Capability::ReadWrite
            }
        }
        Some(_) => Capability::ReadWrite,
        None => Capability::ReadOnly,
    };

    let guard = session
        .db()
        .await
        .connections_for_buffer_update(project_id, session.connection_id, capability)
        .await?;

    let (host, guests) = &*guard;

    // Host and guests all receive the message; `broadcast` is given the
    // sender's id as its first argument so it can be excluded.
    broadcast(
        Some(session.connection_id),
        guests.iter().chain([host]).copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2457
2458/// Notify other participants that a project has been updated.
2459async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2460    request: T,
2461    session: Session,
2462) -> Result<()> {
2463    let project_id = ProjectId::from_proto(request.remote_entity_id());
2464    let project_connection_ids = session
2465        .db()
2466        .await
2467        .project_connection_ids(project_id, session.connection_id, false)
2468        .await?;
2469
2470    broadcast(
2471        Some(session.connection_id),
2472        project_connection_ids.iter().copied(),
2473        |connection_id| {
2474            session
2475                .peer
2476                .forward_send(session.connection_id, connection_id, request.clone())
2477        },
2478    );
2479    Ok(())
2480}
2481
2482/// Start following another user in a call.
2483async fn follow(
2484    request: proto::Follow,
2485    response: Response<proto::Follow>,
2486    session: Session,
2487) -> Result<()> {
2488    let room_id = RoomId::from_proto(request.room_id);
2489    let project_id = request.project_id.map(ProjectId::from_proto);
2490    let leader_id = request.leader_id.context("invalid leader id")?.into();
2491    let follower_id = session.connection_id;
2492
2493    session
2494        .db()
2495        .await
2496        .check_room_participants(room_id, leader_id, session.connection_id)
2497        .await?;
2498
2499    let response_payload = session
2500        .peer
2501        .forward_request(session.connection_id, leader_id, request)
2502        .await?;
2503    response.send(response_payload)?;
2504
2505    if let Some(project_id) = project_id {
2506        let room = session
2507            .db()
2508            .await
2509            .follow(room_id, project_id, leader_id, follower_id)
2510            .await?;
2511        room_updated(&room, &session.peer);
2512    }
2513
2514    Ok(())
2515}
2516
2517/// Stop following another user in a call.
2518async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2519    let room_id = RoomId::from_proto(request.room_id);
2520    let project_id = request.project_id.map(ProjectId::from_proto);
2521    let leader_id = request.leader_id.context("invalid leader id")?.into();
2522    let follower_id = session.connection_id;
2523
2524    session
2525        .db()
2526        .await
2527        .check_room_participants(room_id, leader_id, session.connection_id)
2528        .await?;
2529
2530    session
2531        .peer
2532        .forward_send(session.connection_id, leader_id, request)?;
2533
2534    if let Some(project_id) = project_id {
2535        let room = session
2536            .db()
2537            .await
2538            .unfollow(room_id, project_id, leader_id, follower_id)
2539            .await?;
2540        room_updated(&room, &session.peer);
2541    }
2542
2543    Ok(())
2544}
2545
2546/// Notify everyone following you of your current location.
2547async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2548    let room_id = RoomId::from_proto(request.room_id);
2549    let database = session.db.lock().await;
2550
2551    let connection_ids = if let Some(project_id) = request.project_id {
2552        let project_id = ProjectId::from_proto(project_id);
2553        database
2554            .project_connection_ids(project_id, session.connection_id, true)
2555            .await?
2556    } else {
2557        database
2558            .room_connection_ids(room_id, session.connection_id)
2559            .await?
2560    };
2561
2562    // For now, don't send view update messages back to that view's current leader.
2563    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2564        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2565        _ => None,
2566    });
2567
2568    for connection_id in connection_ids.iter().cloned() {
2569        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2570            session
2571                .peer
2572                .forward_send(session.connection_id, connection_id, request.clone())?;
2573        }
2574    }
2575    Ok(())
2576}
2577
2578/// Get public data about users.
2579async fn get_users(
2580    request: proto::GetUsers,
2581    response: Response<proto::GetUsers>,
2582    session: Session,
2583) -> Result<()> {
2584    let user_ids = request
2585        .user_ids
2586        .into_iter()
2587        .map(UserId::from_proto)
2588        .collect();
2589    let users = session
2590        .db()
2591        .await
2592        .get_users_by_ids(user_ids)
2593        .await?
2594        .into_iter()
2595        .map(|user| proto::User {
2596            id: user.id.to_proto(),
2597            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2598            github_login: user.github_login,
2599            name: user.name,
2600        })
2601        .collect();
2602    response.send(proto::UsersResponse { users })?;
2603    Ok(())
2604}
2605
2606/// Search for users (to invite) buy Github login
2607async fn fuzzy_search_users(
2608    request: proto::FuzzySearchUsers,
2609    response: Response<proto::FuzzySearchUsers>,
2610    session: Session,
2611) -> Result<()> {
2612    let query = request.query;
2613    let users = match query.len() {
2614        0 => vec![],
2615        1 | 2 => session
2616            .db()
2617            .await
2618            .get_user_by_github_login(&query)
2619            .await?
2620            .into_iter()
2621            .collect(),
2622        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2623    };
2624    let users = users
2625        .into_iter()
2626        .filter(|user| user.id != session.user_id())
2627        .map(|user| proto::User {
2628            id: user.id.to_proto(),
2629            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2630            github_login: user.github_login,
2631            name: user.name,
2632        })
2633        .collect();
2634    response.send(proto::UsersResponse { users })?;
2635    Ok(())
2636}
2637
2638/// Send a contact request to another user.
2639async fn request_contact(
2640    request: proto::RequestContact,
2641    response: Response<proto::RequestContact>,
2642    session: Session,
2643) -> Result<()> {
2644    let requester_id = session.user_id();
2645    let responder_id = UserId::from_proto(request.responder_id);
2646    if requester_id == responder_id {
2647        return Err(anyhow!("cannot add yourself as a contact"))?;
2648    }
2649
2650    let notifications = session
2651        .db()
2652        .await
2653        .send_contact_request(requester_id, responder_id)
2654        .await?;
2655
2656    // Update outgoing contact requests of requester
2657    let mut update = proto::UpdateContacts::default();
2658    update.outgoing_requests.push(responder_id.to_proto());
2659    for connection_id in session
2660        .connection_pool()
2661        .await
2662        .user_connection_ids(requester_id)
2663    {
2664        session.peer.send(connection_id, update.clone())?;
2665    }
2666
2667    // Update incoming contact requests of responder
2668    let mut update = proto::UpdateContacts::default();
2669    update
2670        .incoming_requests
2671        .push(proto::IncomingContactRequest {
2672            requester_id: requester_id.to_proto(),
2673        });
2674    let connection_pool = session.connection_pool().await;
2675    for connection_id in connection_pool.user_connection_ids(responder_id) {
2676        session.peer.send(connection_id, update.clone())?;
2677    }
2678
2679    send_notifications(&connection_pool, &session.peer, notifications);
2680
2681    response.send(proto::Ack {})?;
2682    Ok(())
2683}
2684
2685/// Accept or decline a contact request
2686async fn respond_to_contact_request(
2687    request: proto::RespondToContactRequest,
2688    response: Response<proto::RespondToContactRequest>,
2689    session: Session,
2690) -> Result<()> {
2691    let responder_id = session.user_id();
2692    let requester_id = UserId::from_proto(request.requester_id);
2693    let db = session.db().await;
2694    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2695        db.dismiss_contact_notification(responder_id, requester_id)
2696            .await?;
2697    } else {
2698        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2699
2700        let notifications = db
2701            .respond_to_contact_request(responder_id, requester_id, accept)
2702            .await?;
2703        let requester_busy = db.is_user_busy(requester_id).await?;
2704        let responder_busy = db.is_user_busy(responder_id).await?;
2705
2706        let pool = session.connection_pool().await;
2707        // Update responder with new contact
2708        let mut update = proto::UpdateContacts::default();
2709        if accept {
2710            update
2711                .contacts
2712                .push(contact_for_user(requester_id, requester_busy, &pool));
2713        }
2714        update
2715            .remove_incoming_requests
2716            .push(requester_id.to_proto());
2717        for connection_id in pool.user_connection_ids(responder_id) {
2718            session.peer.send(connection_id, update.clone())?;
2719        }
2720
2721        // Update requester with new contact
2722        let mut update = proto::UpdateContacts::default();
2723        if accept {
2724            update
2725                .contacts
2726                .push(contact_for_user(responder_id, responder_busy, &pool));
2727        }
2728        update
2729            .remove_outgoing_requests
2730            .push(responder_id.to_proto());
2731
2732        for connection_id in pool.user_connection_ids(requester_id) {
2733            session.peer.send(connection_id, update.clone())?;
2734        }
2735
2736        send_notifications(&pool, &session.peer, notifications);
2737    }
2738
2739    response.send(proto::Ack {})?;
2740    Ok(())
2741}
2742
2743/// Remove a contact.
2744async fn remove_contact(
2745    request: proto::RemoveContact,
2746    response: Response<proto::RemoveContact>,
2747    session: Session,
2748) -> Result<()> {
2749    let requester_id = session.user_id();
2750    let responder_id = UserId::from_proto(request.user_id);
2751    let db = session.db().await;
2752    let (contact_accepted, deleted_notification_id) =
2753        db.remove_contact(requester_id, responder_id).await?;
2754
2755    let pool = session.connection_pool().await;
2756    // Update outgoing contact requests of requester
2757    let mut update = proto::UpdateContacts::default();
2758    if contact_accepted {
2759        update.remove_contacts.push(responder_id.to_proto());
2760    } else {
2761        update
2762            .remove_outgoing_requests
2763            .push(responder_id.to_proto());
2764    }
2765    for connection_id in pool.user_connection_ids(requester_id) {
2766        session.peer.send(connection_id, update.clone())?;
2767    }
2768
2769    // Update incoming contact requests of responder
2770    let mut update = proto::UpdateContacts::default();
2771    if contact_accepted {
2772        update.remove_contacts.push(requester_id.to_proto());
2773    } else {
2774        update
2775            .remove_incoming_requests
2776            .push(requester_id.to_proto());
2777    }
2778    for connection_id in pool.user_connection_ids(responder_id) {
2779        session.peer.send(connection_id, update.clone())?;
2780        if let Some(notification_id) = deleted_notification_id {
2781            session.peer.send(
2782                connection_id,
2783                proto::DeleteNotification {
2784                    notification_id: notification_id.to_proto(),
2785                },
2786            )?;
2787        }
2788    }
2789
2790    response.send(proto::Ack {})?;
2791    Ok(())
2792}
2793
2794fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2795    version.0.minor() < 139
2796}
2797
2798async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2799    if is_staff {
2800        return Ok(proto::Plan::ZedPro);
2801    }
2802
2803    let subscription = db.get_active_billing_subscription(user_id).await?;
2804    let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2805
2806    let plan = if let Some(subscription_kind) = subscription_kind {
2807        match subscription_kind {
2808            SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2809            SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2810            SubscriptionKind::ZedFree => proto::Plan::Free,
2811        }
2812    } else {
2813        proto::Plan::Free
2814    };
2815
2816    Ok(plan)
2817}
2818
/// Builds the `UpdateUserPlan` message describing `user`'s plan, billing
/// state, and usage.
///
/// - `is_staff`: forwarded to `current_plan` (which reports staff as
///   `ZedPro`) and to `Model::current_period`; staff are also always
///   reported as having usage-based billing enabled.
/// - `llm_db`: when `None`, neither the subscription period nor usage is
///   queried.
///
/// Whenever no usage record is available, the message carries zero usage
/// together with the plan's limits (`make_default_subscription_usage`).
/// `account_too_young` is set for accounts younger than
/// `MIN_ACCOUNT_AGE_FOR_LLM_USE`, unless the plan is `ZedPro` or the user has
/// the `BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG` feature flag.
async fn make_update_user_plan_message(
    user: &User,
    is_staff: bool,
    db: &Arc<Database>,
    llm_db: Option<Arc<LlmDatabase>>,
) -> Result<proto::UpdateUserPlan> {
    let feature_flags = db.get_user_flags(user.id).await?;
    let plan = current_plan(db, user.id, is_staff).await?;
    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
    let billing_preferences = db.get_billing_preferences(user.id).await?;

    // Period and usage are only looked up when an LLM database is configured.
    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
        let subscription = db.get_active_billing_subscription(user.id).await?;

        let subscription_period =
            crate::db::billing_subscription::Model::current_period(subscription, is_staff);

        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
            llm_db
                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
                .await?
        } else {
            None
        };

        (subscription_period, usage)
    } else {
        (None, None)
    };

    let bypass_account_age_check = feature_flags
        .iter()
        .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
    let account_too_young = !matches!(plan, proto::Plan::ZedPro)
        && !bypass_account_age_check
        && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;

    // NOTE(review): the `as u64` timestamp casts assume post-1970 dates; a
    // negative timestamp would wrap rather than error.
    Ok(proto::UpdateUserPlan {
        plan: plan.into(),
        trial_started_at: billing_customer
            .as_ref()
            .and_then(|billing_customer| billing_customer.trial_started_at)
            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
        is_usage_based_billing_enabled: if is_staff {
            Some(true)
        } else {
            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
        },
        subscription_period: subscription_period.map(|(started_at, ended_at)| {
            proto::SubscriptionPeriod {
                started_at: started_at.timestamp() as u64,
                ended_at: ended_at.timestamp() as u64,
            }
        }),
        account_too_young: Some(account_too_young),
        has_overdue_invoices: billing_customer
            .map(|billing_customer| billing_customer.has_overdue_invoices),
        // Always populated: real usage when found, otherwise zero usage with
        // the plan's limits.
        usage: Some(
            usage
                .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
                .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
        ),
    })
}
2883
2884fn model_requests_limit(
2885    plan: cloud_llm_client::Plan,
2886    feature_flags: &Vec<String>,
2887) -> cloud_llm_client::UsageLimit {
2888    match plan.model_requests_limit() {
2889        cloud_llm_client::UsageLimit::Limited(limit) => {
2890            let limit = if plan == cloud_llm_client::Plan::ZedProTrial
2891                && feature_flags
2892                    .iter()
2893                    .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2894            {
2895                1_000
2896            } else {
2897                limit
2898            };
2899
2900            cloud_llm_client::UsageLimit::Limited(limit)
2901        }
2902        cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited,
2903    }
2904}
2905
2906fn subscription_usage_to_proto(
2907    plan: proto::Plan,
2908    usage: crate::llm::db::subscription_usage::Model,
2909    feature_flags: &Vec<String>,
2910) -> proto::SubscriptionUsage {
2911    let plan = match plan {
2912        proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2913        proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
2914        proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
2915    };
2916
2917    proto::SubscriptionUsage {
2918        model_requests_usage_amount: usage.model_requests as u32,
2919        model_requests_usage_limit: Some(proto::UsageLimit {
2920            variant: Some(match model_requests_limit(plan, feature_flags) {
2921                cloud_llm_client::UsageLimit::Limited(limit) => {
2922                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2923                        limit: limit as u32,
2924                    })
2925                }
2926                cloud_llm_client::UsageLimit::Unlimited => {
2927                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2928                }
2929            }),
2930        }),
2931        edit_predictions_usage_amount: usage.edit_predictions as u32,
2932        edit_predictions_usage_limit: Some(proto::UsageLimit {
2933            variant: Some(match plan.edit_predictions_limit() {
2934                cloud_llm_client::UsageLimit::Limited(limit) => {
2935                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2936                        limit: limit as u32,
2937                    })
2938                }
2939                cloud_llm_client::UsageLimit::Unlimited => {
2940                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2941                }
2942            }),
2943        }),
2944    }
2945}
2946
2947fn make_default_subscription_usage(
2948    plan: proto::Plan,
2949    feature_flags: &Vec<String>,
2950) -> proto::SubscriptionUsage {
2951    let plan = match plan {
2952        proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2953        proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
2954        proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
2955    };
2956
2957    proto::SubscriptionUsage {
2958        model_requests_usage_amount: 0,
2959        model_requests_usage_limit: Some(proto::UsageLimit {
2960            variant: Some(match model_requests_limit(plan, feature_flags) {
2961                cloud_llm_client::UsageLimit::Limited(limit) => {
2962                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2963                        limit: limit as u32,
2964                    })
2965                }
2966                cloud_llm_client::UsageLimit::Unlimited => {
2967                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2968                }
2969            }),
2970        }),
2971        edit_predictions_usage_amount: 0,
2972        edit_predictions_usage_limit: Some(proto::UsageLimit {
2973            variant: Some(match plan.edit_predictions_limit() {
2974                cloud_llm_client::UsageLimit::Limited(limit) => {
2975                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2976                        limit: limit as u32,
2977                    })
2978                }
2979                cloud_llm_client::UsageLimit::Unlimited => {
2980                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2981                }
2982            }),
2983        }),
2984    }
2985}
2986
2987async fn update_user_plan(session: &Session) -> Result<()> {
2988    let db = session.db().await;
2989
2990    let update_user_plan = make_update_user_plan_message(
2991        session.principal.user(),
2992        session.is_staff(),
2993        &db.0,
2994        session.app_state.llm_db.clone(),
2995    )
2996    .await?;
2997
2998    session
2999        .peer
3000        .send(session.connection_id, update_user_plan)
3001        .trace_err();
3002
3003    Ok(())
3004}
3005
3006async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3007    subscribe_user_to_channels(session.user_id(), &session).await?;
3008    Ok(())
3009}
3010
3011async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3012    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3013    let mut pool = session.connection_pool().await;
3014    for membership in &channels_for_user.channel_memberships {
3015        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3016    }
3017    session.peer.send(
3018        session.connection_id,
3019        build_update_user_channels(&channels_for_user),
3020    )?;
3021    session.peer.send(
3022        session.connection_id,
3023        build_channels_update(channels_for_user),
3024    )?;
3025    Ok(())
3026}
3027
3028/// Creates a new channel.
3029async fn create_channel(
3030    request: proto::CreateChannel,
3031    response: Response<proto::CreateChannel>,
3032    session: Session,
3033) -> Result<()> {
3034    let db = session.db().await;
3035
3036    let parent_id = request.parent_id.map(ChannelId::from_proto);
3037    let (channel, membership) = db
3038        .create_channel(&request.name, parent_id, session.user_id())
3039        .await?;
3040
3041    let root_id = channel.root_id();
3042    let channel = Channel::from_model(channel);
3043
3044    response.send(proto::CreateChannelResponse {
3045        channel: Some(channel.to_proto()),
3046        parent_id: request.parent_id,
3047    })?;
3048
3049    let mut connection_pool = session.connection_pool().await;
3050    if let Some(membership) = membership {
3051        connection_pool.subscribe_to_channel(
3052            membership.user_id,
3053            membership.channel_id,
3054            membership.role,
3055        );
3056        let update = proto::UpdateUserChannels {
3057            channel_memberships: vec![proto::ChannelMembership {
3058                channel_id: membership.channel_id.to_proto(),
3059                role: membership.role.into(),
3060            }],
3061            ..Default::default()
3062        };
3063        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3064            session.peer.send(connection_id, update.clone())?;
3065        }
3066    }
3067
3068    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3069        if !role.can_see_channel(channel.visibility) {
3070            continue;
3071        }
3072
3073        let update = proto::UpdateChannels {
3074            channels: vec![channel.to_proto()],
3075            ..Default::default()
3076        };
3077        session.peer.send(connection_id, update.clone())?;
3078    }
3079
3080    Ok(())
3081}
3082
3083/// Delete a channel
3084async fn delete_channel(
3085    request: proto::DeleteChannel,
3086    response: Response<proto::DeleteChannel>,
3087    session: Session,
3088) -> Result<()> {
3089    let db = session.db().await;
3090
3091    let channel_id = request.channel_id;
3092    let (root_channel, removed_channels) = db
3093        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3094        .await?;
3095    response.send(proto::Ack {})?;
3096
3097    // Notify members of removed channels
3098    let mut update = proto::UpdateChannels::default();
3099    update
3100        .delete_channels
3101        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3102
3103    let connection_pool = session.connection_pool().await;
3104    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3105        session.peer.send(connection_id, update.clone())?;
3106    }
3107
3108    Ok(())
3109}
3110
3111/// Invite someone to join a channel.
3112async fn invite_channel_member(
3113    request: proto::InviteChannelMember,
3114    response: Response<proto::InviteChannelMember>,
3115    session: Session,
3116) -> Result<()> {
3117    let db = session.db().await;
3118    let channel_id = ChannelId::from_proto(request.channel_id);
3119    let invitee_id = UserId::from_proto(request.user_id);
3120    let InviteMemberResult {
3121        channel,
3122        notifications,
3123    } = db
3124        .invite_channel_member(
3125            channel_id,
3126            invitee_id,
3127            session.user_id(),
3128            request.role().into(),
3129        )
3130        .await?;
3131
3132    let update = proto::UpdateChannels {
3133        channel_invitations: vec![channel.to_proto()],
3134        ..Default::default()
3135    };
3136
3137    let connection_pool = session.connection_pool().await;
3138    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3139        session.peer.send(connection_id, update.clone())?;
3140    }
3141
3142    send_notifications(&connection_pool, &session.peer, notifications);
3143
3144    response.send(proto::Ack {})?;
3145    Ok(())
3146}
3147
3148/// remove someone from a channel
3149async fn remove_channel_member(
3150    request: proto::RemoveChannelMember,
3151    response: Response<proto::RemoveChannelMember>,
3152    session: Session,
3153) -> Result<()> {
3154    let db = session.db().await;
3155    let channel_id = ChannelId::from_proto(request.channel_id);
3156    let member_id = UserId::from_proto(request.user_id);
3157
3158    let RemoveChannelMemberResult {
3159        membership_update,
3160        notification_id,
3161    } = db
3162        .remove_channel_member(channel_id, member_id, session.user_id())
3163        .await?;
3164
3165    let mut connection_pool = session.connection_pool().await;
3166    notify_membership_updated(
3167        &mut connection_pool,
3168        membership_update,
3169        member_id,
3170        &session.peer,
3171    );
3172    for connection_id in connection_pool.user_connection_ids(member_id) {
3173        if let Some(notification_id) = notification_id {
3174            session
3175                .peer
3176                .send(
3177                    connection_id,
3178                    proto::DeleteNotification {
3179                        notification_id: notification_id.to_proto(),
3180                    },
3181                )
3182                .trace_err();
3183        }
3184    }
3185
3186    response.send(proto::Ack {})?;
3187    Ok(())
3188}
3189
3190/// Toggle the channel between public and private.
3191/// Care is taken to maintain the invariant that public channels only descend from public channels,
3192/// (though members-only channels can appear at any point in the hierarchy).
3193async fn set_channel_visibility(
3194    request: proto::SetChannelVisibility,
3195    response: Response<proto::SetChannelVisibility>,
3196    session: Session,
3197) -> Result<()> {
3198    let db = session.db().await;
3199    let channel_id = ChannelId::from_proto(request.channel_id);
3200    let visibility = request.visibility().into();
3201
3202    let channel_model = db
3203        .set_channel_visibility(channel_id, visibility, session.user_id())
3204        .await?;
3205    let root_id = channel_model.root_id();
3206    let channel = Channel::from_model(channel_model);
3207
3208    let mut connection_pool = session.connection_pool().await;
3209    for (user_id, role) in connection_pool
3210        .channel_user_ids(root_id)
3211        .collect::<Vec<_>>()
3212        .into_iter()
3213    {
3214        let update = if role.can_see_channel(channel.visibility) {
3215            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3216            proto::UpdateChannels {
3217                channels: vec![channel.to_proto()],
3218                ..Default::default()
3219            }
3220        } else {
3221            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3222            proto::UpdateChannels {
3223                delete_channels: vec![channel.id.to_proto()],
3224                ..Default::default()
3225            }
3226        };
3227
3228        for connection_id in connection_pool.user_connection_ids(user_id) {
3229            session.peer.send(connection_id, update.clone())?;
3230        }
3231    }
3232
3233    response.send(proto::Ack {})?;
3234    Ok(())
3235}
3236
3237/// Alter the role for a user in the channel.
3238async fn set_channel_member_role(
3239    request: proto::SetChannelMemberRole,
3240    response: Response<proto::SetChannelMemberRole>,
3241    session: Session,
3242) -> Result<()> {
3243    let db = session.db().await;
3244    let channel_id = ChannelId::from_proto(request.channel_id);
3245    let member_id = UserId::from_proto(request.user_id);
3246    let result = db
3247        .set_channel_member_role(
3248            channel_id,
3249            session.user_id(),
3250            member_id,
3251            request.role().into(),
3252        )
3253        .await?;
3254
3255    match result {
3256        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3257            let mut connection_pool = session.connection_pool().await;
3258            notify_membership_updated(
3259                &mut connection_pool,
3260                membership_update,
3261                member_id,
3262                &session.peer,
3263            )
3264        }
3265        db::SetMemberRoleResult::InviteUpdated(channel) => {
3266            let update = proto::UpdateChannels {
3267                channel_invitations: vec![channel.to_proto()],
3268                ..Default::default()
3269            };
3270
3271            for connection_id in session
3272                .connection_pool()
3273                .await
3274                .user_connection_ids(member_id)
3275            {
3276                session.peer.send(connection_id, update.clone())?;
3277            }
3278        }
3279    }
3280
3281    response.send(proto::Ack {})?;
3282    Ok(())
3283}
3284
3285/// Change the name of a channel
3286async fn rename_channel(
3287    request: proto::RenameChannel,
3288    response: Response<proto::RenameChannel>,
3289    session: Session,
3290) -> Result<()> {
3291    let db = session.db().await;
3292    let channel_id = ChannelId::from_proto(request.channel_id);
3293    let channel_model = db
3294        .rename_channel(channel_id, session.user_id(), &request.name)
3295        .await?;
3296    let root_id = channel_model.root_id();
3297    let channel = Channel::from_model(channel_model);
3298
3299    response.send(proto::RenameChannelResponse {
3300        channel: Some(channel.to_proto()),
3301    })?;
3302
3303    let connection_pool = session.connection_pool().await;
3304    let update = proto::UpdateChannels {
3305        channels: vec![channel.to_proto()],
3306        ..Default::default()
3307    };
3308    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3309        if role.can_see_channel(channel.visibility) {
3310            session.peer.send(connection_id, update.clone())?;
3311        }
3312    }
3313
3314    Ok(())
3315}
3316
3317/// Move a channel to a new parent.
3318async fn move_channel(
3319    request: proto::MoveChannel,
3320    response: Response<proto::MoveChannel>,
3321    session: Session,
3322) -> Result<()> {
3323    let channel_id = ChannelId::from_proto(request.channel_id);
3324    let to = ChannelId::from_proto(request.to);
3325
3326    let (root_id, channels) = session
3327        .db()
3328        .await
3329        .move_channel(channel_id, to, session.user_id())
3330        .await?;
3331
3332    let connection_pool = session.connection_pool().await;
3333    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3334        let channels = channels
3335            .iter()
3336            .filter_map(|channel| {
3337                if role.can_see_channel(channel.visibility) {
3338                    Some(channel.to_proto())
3339                } else {
3340                    None
3341                }
3342            })
3343            .collect::<Vec<_>>();
3344        if channels.is_empty() {
3345            continue;
3346        }
3347
3348        let update = proto::UpdateChannels {
3349            channels,
3350            ..Default::default()
3351        };
3352
3353        session.peer.send(connection_id, update.clone())?;
3354    }
3355
3356    response.send(Ack {})?;
3357    Ok(())
3358}
3359
3360async fn reorder_channel(
3361    request: proto::ReorderChannel,
3362    response: Response<proto::ReorderChannel>,
3363    session: Session,
3364) -> Result<()> {
3365    let channel_id = ChannelId::from_proto(request.channel_id);
3366    let direction = request.direction();
3367
3368    let updated_channels = session
3369        .db()
3370        .await
3371        .reorder_channel(channel_id, direction, session.user_id())
3372        .await?;
3373
3374    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3375        let connection_pool = session.connection_pool().await;
3376        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3377            let channels = updated_channels
3378                .iter()
3379                .filter_map(|channel| {
3380                    if role.can_see_channel(channel.visibility) {
3381                        Some(channel.to_proto())
3382                    } else {
3383                        None
3384                    }
3385                })
3386                .collect::<Vec<_>>();
3387
3388            if channels.is_empty() {
3389                continue;
3390            }
3391
3392            let update = proto::UpdateChannels {
3393                channels,
3394                ..Default::default()
3395            };
3396
3397            session.peer.send(connection_id, update.clone())?;
3398        }
3399    }
3400
3401    response.send(Ack {})?;
3402    Ok(())
3403}
3404
3405/// Get the list of channel members
3406async fn get_channel_members(
3407    request: proto::GetChannelMembers,
3408    response: Response<proto::GetChannelMembers>,
3409    session: Session,
3410) -> Result<()> {
3411    let db = session.db().await;
3412    let channel_id = ChannelId::from_proto(request.channel_id);
3413    let limit = if request.limit == 0 {
3414        u16::MAX as u64
3415    } else {
3416        request.limit
3417    };
3418    let (members, users) = db
3419        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3420        .await?;
3421    response.send(proto::GetChannelMembersResponse { members, users })?;
3422    Ok(())
3423}
3424
3425/// Accept or decline a channel invitation.
3426async fn respond_to_channel_invite(
3427    request: proto::RespondToChannelInvite,
3428    response: Response<proto::RespondToChannelInvite>,
3429    session: Session,
3430) -> Result<()> {
3431    let db = session.db().await;
3432    let channel_id = ChannelId::from_proto(request.channel_id);
3433    let RespondToChannelInvite {
3434        membership_update,
3435        notifications,
3436    } = db
3437        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3438        .await?;
3439
3440    let mut connection_pool = session.connection_pool().await;
3441    if let Some(membership_update) = membership_update {
3442        notify_membership_updated(
3443            &mut connection_pool,
3444            membership_update,
3445            session.user_id(),
3446            &session.peer,
3447        );
3448    } else {
3449        let update = proto::UpdateChannels {
3450            remove_channel_invitations: vec![channel_id.to_proto()],
3451            ..Default::default()
3452        };
3453
3454        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3455            session.peer.send(connection_id, update.clone())?;
3456        }
3457    };
3458
3459    send_notifications(&connection_pool, &session.peer, notifications);
3460
3461    response.send(proto::Ack {})?;
3462
3463    Ok(())
3464}
3465
3466/// Join the channels' room
3467async fn join_channel(
3468    request: proto::JoinChannel,
3469    response: Response<proto::JoinChannel>,
3470    session: Session,
3471) -> Result<()> {
3472    let channel_id = ChannelId::from_proto(request.channel_id);
3473    join_channel_internal(channel_id, Box::new(response), session).await
3474}
3475
3476trait JoinChannelInternalResponse {
3477    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3478}
3479impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3480    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3481        Response::<proto::JoinChannel>::send(self, result)
3482    }
3483}
3484impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3485    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3486        Response::<proto::JoinRoom>::send(self, result)
3487    }
3488}
3489
/// Join the room belonging to `channel_id` on behalf of the session's user.
///
/// Shared by handlers whose response type implements
/// `JoinChannelInternalResponse` (the `proto::JoinChannel` and `proto::JoinRoom`
/// responses). Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first.
/// 2. Join the channel's room in the database.
/// 3. Reply with the room, the channel id and, when a LiveKit client is
///    configured and a token can be issued, LiveKit connection info.
/// 4. Propagate any membership change, notify room participants, broadcast the
///    channel update, and refresh the user's contacts.
///
/// Note that the client receives its `JoinRoomResponse` before step 4, so an
/// error in step 4 (e.g. "channel not returned") is reported after the join
/// has already been acknowledged.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room
            // (presumably because `leave_room_for_session` acquires it itself),
            // then take it again for the join below.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when token creation
        // fails: `trace_err()` turns the error into `None` (presumably after
        // logging it), and `?` then short-circuits this closure.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests get a guest token and are not allowed to publish;
                    // every other role gets a regular room token.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        // Both the database guard (`db`) and this connection-pool guard stay
        // held until the end of this block.
        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The guards from the block above have been released; the connection pool
    // is re-acquired here just for the channel broadcast.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3583
3584/// Start editing the channel notes
3585async fn join_channel_buffer(
3586    request: proto::JoinChannelBuffer,
3587    response: Response<proto::JoinChannelBuffer>,
3588    session: Session,
3589) -> Result<()> {
3590    let db = session.db().await;
3591    let channel_id = ChannelId::from_proto(request.channel_id);
3592
3593    let open_response = db
3594        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3595        .await?;
3596
3597    let collaborators = open_response.collaborators.clone();
3598    response.send(open_response)?;
3599
3600    let update = UpdateChannelBufferCollaborators {
3601        channel_id: channel_id.to_proto(),
3602        collaborators: collaborators.clone(),
3603    };
3604    channel_buffer_updated(
3605        session.connection_id,
3606        collaborators
3607            .iter()
3608            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3609        &update,
3610        &session.peer,
3611    );
3612
3613    Ok(())
3614}
3615
3616/// Edit the channel notes
3617async fn update_channel_buffer(
3618    request: proto::UpdateChannelBuffer,
3619    session: Session,
3620) -> Result<()> {
3621    let db = session.db().await;
3622    let channel_id = ChannelId::from_proto(request.channel_id);
3623
3624    let (collaborators, epoch, version) = db
3625        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3626        .await?;
3627
3628    channel_buffer_updated(
3629        session.connection_id,
3630        collaborators.clone(),
3631        &proto::UpdateChannelBuffer {
3632            channel_id: channel_id.to_proto(),
3633            operations: request.operations,
3634        },
3635        &session.peer,
3636    );
3637
3638    let pool = &*session.connection_pool().await;
3639
3640    let non_collaborators =
3641        pool.channel_connection_ids(channel_id)
3642            .filter_map(|(connection_id, _)| {
3643                if collaborators.contains(&connection_id) {
3644                    None
3645                } else {
3646                    Some(connection_id)
3647                }
3648            });
3649
3650    broadcast(None, non_collaborators, |peer_id| {
3651        session.peer.send(
3652            peer_id,
3653            proto::UpdateChannels {
3654                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3655                    channel_id: channel_id.to_proto(),
3656                    epoch: epoch as u64,
3657                    version: version.clone(),
3658                }],
3659                ..Default::default()
3660            },
3661        )
3662    });
3663
3664    Ok(())
3665}
3666
3667/// Rejoin the channel notes after a connection blip
3668async fn rejoin_channel_buffers(
3669    request: proto::RejoinChannelBuffers,
3670    response: Response<proto::RejoinChannelBuffers>,
3671    session: Session,
3672) -> Result<()> {
3673    let db = session.db().await;
3674    let buffers = db
3675        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3676        .await?;
3677
3678    for rejoined_buffer in &buffers {
3679        let collaborators_to_notify = rejoined_buffer
3680            .buffer
3681            .collaborators
3682            .iter()
3683            .filter_map(|c| Some(c.peer_id?.into()));
3684        channel_buffer_updated(
3685            session.connection_id,
3686            collaborators_to_notify,
3687            &proto::UpdateChannelBufferCollaborators {
3688                channel_id: rejoined_buffer.buffer.channel_id,
3689                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3690            },
3691            &session.peer,
3692        );
3693    }
3694
3695    response.send(proto::RejoinChannelBuffersResponse {
3696        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3697    })?;
3698
3699    Ok(())
3700}
3701
3702/// Stop editing the channel notes
3703async fn leave_channel_buffer(
3704    request: proto::LeaveChannelBuffer,
3705    response: Response<proto::LeaveChannelBuffer>,
3706    session: Session,
3707) -> Result<()> {
3708    let db = session.db().await;
3709    let channel_id = ChannelId::from_proto(request.channel_id);
3710
3711    let left_buffer = db
3712        .leave_channel_buffer(channel_id, session.connection_id)
3713        .await?;
3714
3715    response.send(Ack {})?;
3716
3717    channel_buffer_updated(
3718        session.connection_id,
3719        left_buffer.connections,
3720        &proto::UpdateChannelBufferCollaborators {
3721            channel_id: channel_id.to_proto(),
3722            collaborators: left_buffer.collaborators,
3723        },
3724        &session.peer,
3725    );
3726
3727    Ok(())
3728}
3729
3730fn channel_buffer_updated<T: EnvelopedMessage>(
3731    sender_id: ConnectionId,
3732    collaborators: impl IntoIterator<Item = ConnectionId>,
3733    message: &T,
3734    peer: &Peer,
3735) {
3736    broadcast(Some(sender_id), collaborators, |peer_id| {
3737        peer.send(peer_id, message.clone())
3738    });
3739}
3740
3741fn send_notifications(
3742    connection_pool: &ConnectionPool,
3743    peer: &Peer,
3744    notifications: db::NotificationBatch,
3745) {
3746    for (user_id, notification) in notifications {
3747        for connection_id in connection_pool.user_connection_ids(user_id) {
3748            if let Err(error) = peer.send(
3749                connection_id,
3750                proto::AddNotification {
3751                    notification: Some(notification.clone()),
3752                },
3753            ) {
3754                tracing::error!(
3755                    "failed to send notification to {:?} {}",
3756                    connection_id,
3757                    error
3758                );
3759            }
3760        }
3761    }
3762}
3763
3764/// Send a message to the channel
3765async fn send_channel_message(
3766    request: proto::SendChannelMessage,
3767    response: Response<proto::SendChannelMessage>,
3768    session: Session,
3769) -> Result<()> {
3770    // Validate the message body.
3771    let body = request.body.trim().to_string();
3772    if body.len() > MAX_MESSAGE_LEN {
3773        return Err(anyhow!("message is too long"))?;
3774    }
3775    if body.is_empty() {
3776        return Err(anyhow!("message can't be blank"))?;
3777    }
3778
3779    // TODO: adjust mentions if body is trimmed
3780
3781    let timestamp = OffsetDateTime::now_utc();
3782    let nonce = request.nonce.context("nonce can't be blank")?;
3783
3784    let channel_id = ChannelId::from_proto(request.channel_id);
3785    let CreatedChannelMessage {
3786        message_id,
3787        participant_connection_ids,
3788        notifications,
3789    } = session
3790        .db()
3791        .await
3792        .create_channel_message(
3793            channel_id,
3794            session.user_id(),
3795            &body,
3796            &request.mentions,
3797            timestamp,
3798            nonce.clone().into(),
3799            request.reply_to_message_id.map(MessageId::from_proto),
3800        )
3801        .await?;
3802
3803    let message = proto::ChannelMessage {
3804        sender_id: session.user_id().to_proto(),
3805        id: message_id.to_proto(),
3806        body,
3807        mentions: request.mentions,
3808        timestamp: timestamp.unix_timestamp() as u64,
3809        nonce: Some(nonce),
3810        reply_to_message_id: request.reply_to_message_id,
3811        edited_at: None,
3812    };
3813    broadcast(
3814        Some(session.connection_id),
3815        participant_connection_ids.clone(),
3816        |connection| {
3817            session.peer.send(
3818                connection,
3819                proto::ChannelMessageSent {
3820                    channel_id: channel_id.to_proto(),
3821                    message: Some(message.clone()),
3822                },
3823            )
3824        },
3825    );
3826    response.send(proto::SendChannelMessageResponse {
3827        message: Some(message),
3828    })?;
3829
3830    let pool = &*session.connection_pool().await;
3831    let non_participants =
3832        pool.channel_connection_ids(channel_id)
3833            .filter_map(|(connection_id, _)| {
3834                if participant_connection_ids.contains(&connection_id) {
3835                    None
3836                } else {
3837                    Some(connection_id)
3838                }
3839            });
3840    broadcast(None, non_participants, |peer_id| {
3841        session.peer.send(
3842            peer_id,
3843            proto::UpdateChannels {
3844                latest_channel_message_ids: vec![proto::ChannelMessageId {
3845                    channel_id: channel_id.to_proto(),
3846                    message_id: message_id.to_proto(),
3847                }],
3848                ..Default::default()
3849            },
3850        )
3851    });
3852    send_notifications(pool, &session.peer, notifications);
3853
3854    Ok(())
3855}
3856
3857/// Delete a channel message
3858async fn remove_channel_message(
3859    request: proto::RemoveChannelMessage,
3860    response: Response<proto::RemoveChannelMessage>,
3861    session: Session,
3862) -> Result<()> {
3863    let channel_id = ChannelId::from_proto(request.channel_id);
3864    let message_id = MessageId::from_proto(request.message_id);
3865    let (connection_ids, existing_notification_ids) = session
3866        .db()
3867        .await
3868        .remove_channel_message(channel_id, message_id, session.user_id())
3869        .await?;
3870
3871    broadcast(
3872        Some(session.connection_id),
3873        connection_ids,
3874        move |connection| {
3875            session.peer.send(connection, request.clone())?;
3876
3877            for notification_id in &existing_notification_ids {
3878                session.peer.send(
3879                    connection,
3880                    proto::DeleteNotification {
3881                        notification_id: (*notification_id).to_proto(),
3882                    },
3883                )?;
3884            }
3885
3886            Ok(())
3887        },
3888    );
3889    response.send(proto::Ack {})?;
3890    Ok(())
3891}
3892
3893async fn update_channel_message(
3894    request: proto::UpdateChannelMessage,
3895    response: Response<proto::UpdateChannelMessage>,
3896    session: Session,
3897) -> Result<()> {
3898    let channel_id = ChannelId::from_proto(request.channel_id);
3899    let message_id = MessageId::from_proto(request.message_id);
3900    let updated_at = OffsetDateTime::now_utc();
3901    let UpdatedChannelMessage {
3902        message_id,
3903        participant_connection_ids,
3904        notifications,
3905        reply_to_message_id,
3906        timestamp,
3907        deleted_mention_notification_ids,
3908        updated_mention_notifications,
3909    } = session
3910        .db()
3911        .await
3912        .update_channel_message(
3913            channel_id,
3914            message_id,
3915            session.user_id(),
3916            request.body.as_str(),
3917            &request.mentions,
3918            updated_at,
3919        )
3920        .await?;
3921
3922    let nonce = request.nonce.clone().context("nonce can't be blank")?;
3923
3924    let message = proto::ChannelMessage {
3925        sender_id: session.user_id().to_proto(),
3926        id: message_id.to_proto(),
3927        body: request.body.clone(),
3928        mentions: request.mentions.clone(),
3929        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3930        nonce: Some(nonce),
3931        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3932        edited_at: Some(updated_at.unix_timestamp() as u64),
3933    };
3934
3935    response.send(proto::Ack {})?;
3936
3937    let pool = &*session.connection_pool().await;
3938    broadcast(
3939        Some(session.connection_id),
3940        participant_connection_ids,
3941        |connection| {
3942            session.peer.send(
3943                connection,
3944                proto::ChannelMessageUpdate {
3945                    channel_id: channel_id.to_proto(),
3946                    message: Some(message.clone()),
3947                },
3948            )?;
3949
3950            for notification_id in &deleted_mention_notification_ids {
3951                session.peer.send(
3952                    connection,
3953                    proto::DeleteNotification {
3954                        notification_id: (*notification_id).to_proto(),
3955                    },
3956                )?;
3957            }
3958
3959            for notification in &updated_mention_notifications {
3960                session.peer.send(
3961                    connection,
3962                    proto::UpdateNotification {
3963                        notification: Some(notification.clone()),
3964                    },
3965                )?;
3966            }
3967
3968            Ok(())
3969        },
3970    );
3971
3972    send_notifications(pool, &session.peer, notifications);
3973
3974    Ok(())
3975}
3976
3977/// Mark a channel message as read
3978async fn acknowledge_channel_message(
3979    request: proto::AckChannelMessage,
3980    session: Session,
3981) -> Result<()> {
3982    let channel_id = ChannelId::from_proto(request.channel_id);
3983    let message_id = MessageId::from_proto(request.message_id);
3984    let notifications = session
3985        .db()
3986        .await
3987        .observe_channel_message(channel_id, session.user_id(), message_id)
3988        .await?;
3989    send_notifications(
3990        &*session.connection_pool().await,
3991        &session.peer,
3992        notifications,
3993    );
3994    Ok(())
3995}
3996
3997/// Mark a buffer version as synced
3998async fn acknowledge_buffer_version(
3999    request: proto::AckBufferOperation,
4000    session: Session,
4001) -> Result<()> {
4002    let buffer_id = BufferId::from_proto(request.buffer_id);
4003    session
4004        .db()
4005        .await
4006        .observe_buffer_version(
4007            buffer_id,
4008            session.user_id(),
4009            request.epoch as i32,
4010            &request.version,
4011        )
4012        .await?;
4013    Ok(())
4014}
4015
4016/// Get a Supermaven API key for the user
4017async fn get_supermaven_api_key(
4018    _request: proto::GetSupermavenApiKey,
4019    response: Response<proto::GetSupermavenApiKey>,
4020    session: Session,
4021) -> Result<()> {
4022    let user_id: String = session.user_id().to_string();
4023    if !session.is_staff() {
4024        return Err(anyhow!("supermaven not enabled for this account"))?;
4025    }
4026
4027    let email = session.email().context("user must have an email")?;
4028
4029    let supermaven_admin_api = session
4030        .supermaven_client
4031        .as_ref()
4032        .context("supermaven not configured")?;
4033
4034    let result = supermaven_admin_api
4035        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4036        .await?;
4037
4038    response.send(proto::GetSupermavenApiKeyResponse {
4039        api_key: result.api_key,
4040    })?;
4041
4042    Ok(())
4043}
4044
4045/// Start receiving chat updates for a channel
4046async fn join_channel_chat(
4047    request: proto::JoinChannelChat,
4048    response: Response<proto::JoinChannelChat>,
4049    session: Session,
4050) -> Result<()> {
4051    let channel_id = ChannelId::from_proto(request.channel_id);
4052
4053    let db = session.db().await;
4054    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4055        .await?;
4056    let messages = db
4057        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4058        .await?;
4059    response.send(proto::JoinChannelChatResponse {
4060        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4061        messages,
4062    })?;
4063    Ok(())
4064}
4065
4066/// Stop receiving chat updates for a channel
4067async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
4068    let channel_id = ChannelId::from_proto(request.channel_id);
4069    session
4070        .db()
4071        .await
4072        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4073        .await?;
4074    Ok(())
4075}
4076
4077/// Retrieve the chat history for a channel
4078async fn get_channel_messages(
4079    request: proto::GetChannelMessages,
4080    response: Response<proto::GetChannelMessages>,
4081    session: Session,
4082) -> Result<()> {
4083    let channel_id = ChannelId::from_proto(request.channel_id);
4084    let messages = session
4085        .db()
4086        .await
4087        .get_channel_messages(
4088            channel_id,
4089            session.user_id(),
4090            MESSAGE_COUNT_PER_PAGE,
4091            Some(MessageId::from_proto(request.before_message_id)),
4092        )
4093        .await?;
4094    response.send(proto::GetChannelMessagesResponse {
4095        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4096        messages,
4097    })?;
4098    Ok(())
4099}
4100
4101/// Retrieve specific chat messages
4102async fn get_channel_messages_by_id(
4103    request: proto::GetChannelMessagesById,
4104    response: Response<proto::GetChannelMessagesById>,
4105    session: Session,
4106) -> Result<()> {
4107    let message_ids = request
4108        .message_ids
4109        .iter()
4110        .map(|id| MessageId::from_proto(*id))
4111        .collect::<Vec<_>>();
4112    let messages = session
4113        .db()
4114        .await
4115        .get_channel_messages_by_id(session.user_id(), &message_ids)
4116        .await?;
4117    response.send(proto::GetChannelMessagesResponse {
4118        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4119        messages,
4120    })?;
4121    Ok(())
4122}
4123
4124/// Retrieve the current users notifications
4125async fn get_notifications(
4126    request: proto::GetNotifications,
4127    response: Response<proto::GetNotifications>,
4128    session: Session,
4129) -> Result<()> {
4130    let notifications = session
4131        .db()
4132        .await
4133        .get_notifications(
4134            session.user_id(),
4135            NOTIFICATION_COUNT_PER_PAGE,
4136            request.before_id.map(db::NotificationId::from_proto),
4137        )
4138        .await?;
4139    response.send(proto::GetNotificationsResponse {
4140        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4141        notifications,
4142    })?;
4143    Ok(())
4144}
4145
4146/// Mark notifications as read
4147async fn mark_notification_as_read(
4148    request: proto::MarkNotificationRead,
4149    response: Response<proto::MarkNotificationRead>,
4150    session: Session,
4151) -> Result<()> {
4152    let database = &session.db().await;
4153    let notifications = database
4154        .mark_notification_as_read_by_id(
4155            session.user_id(),
4156            NotificationId::from_proto(request.notification_id),
4157        )
4158        .await?;
4159    send_notifications(
4160        &*session.connection_pool().await,
4161        &session.peer,
4162        notifications,
4163    );
4164    response.send(proto::Ack {})?;
4165    Ok(())
4166}
4167
4168/// Get the current users information
4169async fn get_private_user_info(
4170    _request: proto::GetPrivateUserInfo,
4171    response: Response<proto::GetPrivateUserInfo>,
4172    session: Session,
4173) -> Result<()> {
4174    let db = session.db().await;
4175
4176    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4177    let user = db
4178        .get_user_by_id(session.user_id())
4179        .await?
4180        .context("user not found")?;
4181    let flags = db.get_user_flags(session.user_id()).await?;
4182
4183    response.send(proto::GetPrivateUserInfoResponse {
4184        metrics_id,
4185        staff: user.admin,
4186        flags,
4187        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4188    })?;
4189    Ok(())
4190}
4191
4192/// Accept the terms of service (tos) on behalf of the current user
4193async fn accept_terms_of_service(
4194    _request: proto::AcceptTermsOfService,
4195    response: Response<proto::AcceptTermsOfService>,
4196    session: Session,
4197) -> Result<()> {
4198    let db = session.db().await;
4199
4200    let accepted_tos_at = Utc::now();
4201    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4202        .await?;
4203
4204    response.send(proto::AcceptTermsOfServiceResponse {
4205        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4206    })?;
4207
4208    // When the user accepts the terms of service, we want to refresh their LLM
4209    // token to grant access.
4210    session
4211        .peer
4212        .send(session.connection_id, proto::RefreshLlmToken {})?;
4213
4214    Ok(())
4215}
4216
4217async fn get_llm_api_token(
4218    _request: proto::GetLlmToken,
4219    response: Response<proto::GetLlmToken>,
4220    session: Session,
4221) -> Result<()> {
4222    let db = session.db().await;
4223
4224    let flags = db.get_user_flags(session.user_id()).await?;
4225
4226    let user_id = session.user_id();
4227    let user = db
4228        .get_user_by_id(user_id)
4229        .await?
4230        .with_context(|| format!("user {user_id} not found"))?;
4231
4232    if user.accepted_tos_at.is_none() {
4233        Err(anyhow!("terms of service not accepted"))?
4234    }
4235
4236    let stripe_client = session
4237        .app_state
4238        .stripe_client
4239        .as_ref()
4240        .context("failed to retrieve Stripe client")?;
4241
4242    let stripe_billing = session
4243        .app_state
4244        .stripe_billing
4245        .as_ref()
4246        .context("failed to retrieve Stripe billing object")?;
4247
4248    let billing_customer = if let Some(billing_customer) =
4249        db.get_billing_customer_by_user_id(user.id).await?
4250    {
4251        billing_customer
4252    } else {
4253        let customer_id = stripe_billing
4254            .find_or_create_customer_by_email(user.email_address.as_deref())
4255            .await?;
4256
4257        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
4258            .await?
4259            .context("billing customer not found")?
4260    };
4261
4262    let billing_subscription =
4263        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
4264            billing_subscription
4265        } else {
4266            let stripe_customer_id =
4267                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
4268
4269            let stripe_subscription = stripe_billing
4270                .subscribe_to_zed_free(stripe_customer_id)
4271                .await?;
4272
4273            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
4274                billing_customer_id: billing_customer.id,
4275                kind: Some(SubscriptionKind::ZedFree),
4276                stripe_subscription_id: stripe_subscription.id.to_string(),
4277                stripe_subscription_status: stripe_subscription.status.into(),
4278                stripe_cancellation_reason: None,
4279                stripe_current_period_start: Some(stripe_subscription.current_period_start),
4280                stripe_current_period_end: Some(stripe_subscription.current_period_end),
4281            })
4282            .await?
4283        };
4284
4285    let billing_preferences = db.get_billing_preferences(user.id).await?;
4286
4287    let token = LlmTokenClaims::create(
4288        &user,
4289        session.is_staff(),
4290        billing_customer,
4291        billing_preferences,
4292        &flags,
4293        billing_subscription,
4294        session.system_id.clone(),
4295        &session.app_state.config,
4296    )?;
4297    response.send(proto::GetLlmTokenResponse { token })?;
4298    Ok(())
4299}
4300
4301fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4302    let message = match message {
4303        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4304        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4305        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4306        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4307        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4308            code: frame.code.into(),
4309            reason: frame.reason.as_str().to_owned().into(),
4310        })),
4311        // We should never receive a frame while reading the message, according
4312        // to the `tungstenite` maintainers:
4313        //
4314        // > It cannot occur when you read messages from the WebSocket, but it
4315        // > can be used when you want to send the raw frames (e.g. you want to
4316        // > send the frames to the WebSocket without composing the full message first).
4317        // >
4318        // > — https://github.com/snapview/tungstenite-rs/issues/268
4319        TungsteniteMessage::Frame(_) => {
4320            bail!("received an unexpected frame while reading the message")
4321        }
4322    };
4323
4324    Ok(message)
4325}
4326
4327fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4328    match message {
4329        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4330        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4331        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4332        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4333        AxumMessage::Close(frame) => {
4334            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4335                code: frame.code.into(),
4336                reason: frame.reason.as_ref().into(),
4337            }))
4338        }
4339    }
4340}
4341
4342fn notify_membership_updated(
4343    connection_pool: &mut ConnectionPool,
4344    result: MembershipUpdated,
4345    user_id: UserId,
4346    peer: &Peer,
4347) {
4348    for membership in &result.new_channels.channel_memberships {
4349        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4350    }
4351    for channel_id in &result.removed_channels {
4352        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4353    }
4354
4355    let user_channels_update = proto::UpdateUserChannels {
4356        channel_memberships: result
4357            .new_channels
4358            .channel_memberships
4359            .iter()
4360            .map(|cm| proto::ChannelMembership {
4361                channel_id: cm.channel_id.to_proto(),
4362                role: cm.role.into(),
4363            })
4364            .collect(),
4365        ..Default::default()
4366    };
4367
4368    let mut update = build_channels_update(result.new_channels);
4369    update.delete_channels = result
4370        .removed_channels
4371        .into_iter()
4372        .map(|id| id.to_proto())
4373        .collect();
4374    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4375
4376    for connection_id in connection_pool.user_connection_ids(user_id) {
4377        peer.send(connection_id, user_channels_update.clone())
4378            .trace_err();
4379        peer.send(connection_id, update.clone()).trace_err();
4380    }
4381}
4382
4383fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4384    proto::UpdateUserChannels {
4385        channel_memberships: channels
4386            .channel_memberships
4387            .iter()
4388            .map(|m| proto::ChannelMembership {
4389                channel_id: m.channel_id.to_proto(),
4390                role: m.role.into(),
4391            })
4392            .collect(),
4393        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4394        observed_channel_message_id: channels.observed_channel_messages.clone(),
4395    }
4396}
4397
4398fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4399    let mut update = proto::UpdateChannels::default();
4400
4401    for channel in channels.channels {
4402        update.channels.push(channel.to_proto());
4403    }
4404
4405    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4406    update.latest_channel_message_ids = channels.latest_channel_messages;
4407
4408    for (channel_id, participants) in channels.channel_participants {
4409        update
4410            .channel_participants
4411            .push(proto::ChannelParticipants {
4412                channel_id: channel_id.to_proto(),
4413                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4414            });
4415    }
4416
4417    for channel in channels.invited_channels {
4418        update.channel_invitations.push(channel.to_proto());
4419    }
4420
4421    update
4422}
4423
4424fn build_initial_contacts_update(
4425    contacts: Vec<db::Contact>,
4426    pool: &ConnectionPool,
4427) -> proto::UpdateContacts {
4428    let mut update = proto::UpdateContacts::default();
4429
4430    for contact in contacts {
4431        match contact {
4432            db::Contact::Accepted { user_id, busy } => {
4433                update.contacts.push(contact_for_user(user_id, busy, pool));
4434            }
4435            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4436            db::Contact::Incoming { user_id } => {
4437                update
4438                    .incoming_requests
4439                    .push(proto::IncomingContactRequest {
4440                        requester_id: user_id.to_proto(),
4441                    })
4442            }
4443        }
4444    }
4445
4446    update
4447}
4448
4449fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4450    proto::Contact {
4451        user_id: user_id.to_proto(),
4452        online: pool.is_user_online(user_id),
4453        busy,
4454    }
4455}
4456
4457fn room_updated(room: &proto::Room, peer: &Peer) {
4458    broadcast(
4459        None,
4460        room.participants
4461            .iter()
4462            .filter_map(|participant| Some(participant.peer_id?.into())),
4463        |peer_id| {
4464            peer.send(
4465                peer_id,
4466                proto::RoomUpdated {
4467                    room: Some(room.clone()),
4468                },
4469            )
4470        },
4471    );
4472}
4473
4474fn channel_updated(
4475    channel: &db::channel::Model,
4476    room: &proto::Room,
4477    peer: &Peer,
4478    pool: &ConnectionPool,
4479) {
4480    let participants = room
4481        .participants
4482        .iter()
4483        .map(|p| p.user_id)
4484        .collect::<Vec<_>>();
4485
4486    broadcast(
4487        None,
4488        pool.channel_connection_ids(channel.root_id())
4489            .filter_map(|(channel_id, role)| {
4490                role.can_see_channel(channel.visibility)
4491                    .then_some(channel_id)
4492            }),
4493        |peer_id| {
4494            peer.send(
4495                peer_id,
4496                proto::UpdateChannels {
4497                    channel_participants: vec![proto::ChannelParticipants {
4498                        channel_id: channel.id.to_proto(),
4499                        participant_user_ids: participants.clone(),
4500                    }],
4501                    ..Default::default()
4502                },
4503            )
4504        },
4505    );
4506}
4507
4508async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4509    let db = session.db().await;
4510
4511    let contacts = db.get_contacts(user_id).await?;
4512    let busy = db.is_user_busy(user_id).await?;
4513
4514    let pool = session.connection_pool().await;
4515    let updated_contact = contact_for_user(user_id, busy, &pool);
4516    for contact in contacts {
4517        if let db::Contact::Accepted {
4518            user_id: contact_user_id,
4519            ..
4520        } = contact
4521        {
4522            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4523                session
4524                    .peer
4525                    .send(
4526                        contact_conn_id,
4527                        proto::UpdateContacts {
4528                            contacts: vec![updated_contact.clone()],
4529                            remove_contacts: Default::default(),
4530                            incoming_requests: Default::default(),
4531                            remove_incoming_requests: Default::default(),
4532                            outgoing_requests: Default::default(),
4533                            remove_outgoing_requests: Default::default(),
4534                        },
4535                    )
4536                    .trace_err();
4537            }
4538        }
4539    }
4540    Ok(())
4541}
4542
4543async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4544    let mut contacts_to_update = HashSet::default();
4545
4546    let room_id;
4547    let canceled_calls_to_user_ids;
4548    let livekit_room;
4549    let delete_livekit_room;
4550    let room;
4551    let channel;
4552
4553    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4554        contacts_to_update.insert(session.user_id());
4555
4556        for project in left_room.left_projects.values() {
4557            project_left(project, session);
4558        }
4559
4560        room_id = RoomId::from_proto(left_room.room.id);
4561        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4562        livekit_room = mem::take(&mut left_room.room.livekit_room);
4563        delete_livekit_room = left_room.deleted;
4564        room = mem::take(&mut left_room.room);
4565        channel = mem::take(&mut left_room.channel);
4566
4567        room_updated(&room, &session.peer);
4568    } else {
4569        return Ok(());
4570    }
4571
4572    if let Some(channel) = channel {
4573        channel_updated(
4574            &channel,
4575            &room,
4576            &session.peer,
4577            &*session.connection_pool().await,
4578        );
4579    }
4580
4581    {
4582        let pool = session.connection_pool().await;
4583        for canceled_user_id in canceled_calls_to_user_ids {
4584            for connection_id in pool.user_connection_ids(canceled_user_id) {
4585                session
4586                    .peer
4587                    .send(
4588                        connection_id,
4589                        proto::CallCanceled {
4590                            room_id: room_id.to_proto(),
4591                        },
4592                    )
4593                    .trace_err();
4594            }
4595            contacts_to_update.insert(canceled_user_id);
4596        }
4597    }
4598
4599    for contact_user_id in contacts_to_update {
4600        update_user_contacts(contact_user_id, session).await?;
4601    }
4602
4603    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4604        live_kit
4605            .remove_participant(livekit_room.clone(), session.user_id().to_string())
4606            .await
4607            .trace_err();
4608
4609        if delete_livekit_room {
4610            live_kit.delete_room(livekit_room).await.trace_err();
4611        }
4612    }
4613
4614    Ok(())
4615}
4616
4617async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4618    let left_channel_buffers = session
4619        .db()
4620        .await
4621        .leave_channel_buffers(session.connection_id)
4622        .await?;
4623
4624    for left_buffer in left_channel_buffers {
4625        channel_buffer_updated(
4626            session.connection_id,
4627            left_buffer.connections,
4628            &proto::UpdateChannelBufferCollaborators {
4629                channel_id: left_buffer.channel_id.to_proto(),
4630                collaborators: left_buffer.collaborators,
4631            },
4632            &session.peer,
4633        );
4634    }
4635
4636    Ok(())
4637}
4638
4639fn project_left(project: &db::LeftProject, session: &Session) {
4640    for connection_id in &project.connection_ids {
4641        if project.should_unshare {
4642            session
4643                .peer
4644                .send(
4645                    *connection_id,
4646                    proto::UnshareProject {
4647                        project_id: project.id.to_proto(),
4648                    },
4649                )
4650                .trace_err();
4651        } else {
4652            session
4653                .peer
4654                .send(
4655                    *connection_id,
4656                    proto::RemoveProjectCollaborator {
4657                        project_id: project.id.to_proto(),
4658                        peer_id: Some(session.connection_id.into()),
4659                    },
4660                )
4661                .trace_err();
4662        }
4663    }
4664}
4665
4666pub trait ResultExt {
4667    type Ok;
4668
4669    fn trace_err(self) -> Option<Self::Ok>;
4670}
4671
4672impl<T, E> ResultExt for Result<T, E>
4673where
4674    E: std::fmt::Debug,
4675{
4676    type Ok = T;
4677
4678    #[track_caller]
4679    fn trace_err(self) -> Option<T> {
4680        match self {
4681            Ok(value) => Some(value),
4682            Err(error) => {
4683                tracing::error!("{:?}", error);
4684                None
4685            }
4686        }
4687    }
4688}