rpc.rs

   1mod connection_pool;
   2
   3use crate::api::billing::find_or_create_billing_customer;
   4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   5use crate::db::billing_subscription::SubscriptionKind;
   6use crate::llm::db::LlmDatabase;
   7use crate::llm::{
   8    AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
   9    MIN_ACCOUNT_AGE_FOR_LLM_USE,
  10};
  11use crate::stripe_client::StripeCustomerId;
  12use crate::{
  13    AppState, Error, Result, auth,
  14    db::{
  15        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
  16        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  17        NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
  18        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  19    },
  20    executor::Executor,
  21};
  22use anyhow::{Context as _, anyhow, bail};
  23use async_tungstenite::tungstenite::{
  24    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  25};
  26use axum::headers::UserAgent;
  27use axum::{
  28    Extension, Router, TypedHeader,
  29    body::Body,
  30    extract::{
  31        ConnectInfo, WebSocketUpgrade,
  32        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  33    },
  34    headers::{Header, HeaderName},
  35    http::StatusCode,
  36    middleware,
  37    response::IntoResponse,
  38    routing::get,
  39};
  40use chrono::Utc;
  41use collections::{HashMap, HashSet};
  42pub use connection_pool::{ConnectionPool, ZedVersion};
  43use core::fmt::{self, Debug, Formatter};
  44use reqwest_client::ReqwestClient;
  45use rpc::proto::{MultiLspQuery, split_repository_update};
  46use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  47
  48use futures::{
  49    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  50    stream::FuturesUnordered,
  51};
  52use prometheus::{IntGauge, register_int_gauge};
  53use rpc::{
  54    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  55    proto::{
  56        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  57        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  58    },
  59};
  60use semantic_version::SemanticVersion;
  61use serde::{Serialize, Serializer};
  62use std::{
  63    any::TypeId,
  64    future::Future,
  65    marker::PhantomData,
  66    mem,
  67    net::SocketAddr,
  68    ops::{Deref, DerefMut},
  69    rc::Rc,
  70    sync::{
  71        Arc, OnceLock,
  72        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  73    },
  74    time::{Duration, Instant},
  75};
  76use time::OffsetDateTime;
  77use tokio::sync::{Semaphore, watch};
  78use tower::ServiceBuilder;
  79use tracing::{
  80    Instrument,
  81    field::{self},
  82    info_span, instrument,
  83};
  84
  85pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  86
  87// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  88pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  89
  90const MESSAGE_COUNT_PER_PAGE: usize = 100;
  91const MAX_MESSAGE_LEN: usize = 1024;
  92const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  93const MAX_CONCURRENT_CONNECTIONS: usize = 512;
  94
  95static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
  96
  97type MessageHandler =
  98    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  99
 100pub struct ConnectionGuard;
 101
 102impl ConnectionGuard {
 103    pub fn try_acquire() -> Result<Self, ()> {
 104        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
 105        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 106            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 107            tracing::error!(
 108                "too many concurrent connections: {}",
 109                current_connections + 1
 110            );
 111            return Err(());
 112        }
 113        Ok(ConnectionGuard)
 114    }
 115}
 116
 117impl Drop for ConnectionGuard {
 118    fn drop(&mut self) {
 119        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 120    }
 121}
 122
 123struct Response<R> {
 124    peer: Arc<Peer>,
 125    receipt: Receipt<R>,
 126    responded: Arc<AtomicBool>,
 127}
 128
 129impl<R: RequestMessage> Response<R> {
 130    fn send(self, payload: R::Response) -> Result<()> {
 131        self.responded.store(true, SeqCst);
 132        self.peer.respond(self.receipt, payload)?;
 133        Ok(())
 134    }
 135}
 136
 137#[derive(Clone, Debug)]
 138pub enum Principal {
 139    User(User),
 140    Impersonated { user: User, admin: User },
 141}
 142
 143impl Principal {
 144    fn user(&self) -> &User {
 145        match self {
 146            Principal::User(user) => user,
 147            Principal::Impersonated { user, .. } => user,
 148        }
 149    }
 150
 151    fn update_span(&self, span: &tracing::Span) {
 152        match &self {
 153            Principal::User(user) => {
 154                span.record("user_id", user.id.0);
 155                span.record("login", &user.github_login);
 156            }
 157            Principal::Impersonated { user, admin } => {
 158                span.record("user_id", user.id.0);
 159                span.record("login", &user.github_login);
 160                span.record("impersonator", &admin.github_login);
 161            }
 162        }
 163    }
 164}
 165
/// Per-connection state passed (by value) to every message handler.
// Field order also fixes drop order; keep it stable.
#[derive(Clone)]
struct Session {
    /// Who is on the other end of this connection (possibly an impersonation).
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared connection pool; lock it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    // NOTE(review): `allow(unused)` suggests nothing reads this field yet — confirm before
    // removing either the attribute or the field.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    _executor: Executor,
}
 181
 182impl Session {
 183    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 184        #[cfg(test)]
 185        tokio::task::yield_now().await;
 186        let guard = self.db.lock().await;
 187        #[cfg(test)]
 188        tokio::task::yield_now().await;
 189        guard
 190    }
 191
 192    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 193        #[cfg(test)]
 194        tokio::task::yield_now().await;
 195        let guard = self.connection_pool.lock();
 196        ConnectionPoolGuard {
 197            guard,
 198            _not_send: PhantomData,
 199        }
 200    }
 201
 202    fn is_staff(&self) -> bool {
 203        match &self.principal {
 204            Principal::User(user) => user.admin,
 205            Principal::Impersonated { .. } => true,
 206        }
 207    }
 208
 209    fn user_id(&self) -> UserId {
 210        match &self.principal {
 211            Principal::User(user) => user.id,
 212            Principal::Impersonated { user, .. } => user.id,
 213        }
 214    }
 215
 216    pub fn email(&self) -> Option<String> {
 217        match &self.principal {
 218            Principal::User(user) => user.email_address.clone(),
 219            Principal::Impersonated { user, .. } => user.email_address.clone(),
 220        }
 221    }
 222}
 223
 224impl Debug for Session {
 225    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 226        let mut result = f.debug_struct("Session");
 227        match &self.principal {
 228            Principal::User(user) => {
 229                result.field("user", &user.github_login);
 230            }
 231            Principal::Impersonated { user, admin } => {
 232                result.field("user", &user.github_login);
 233                result.field("impersonator", &admin.github_login);
 234            }
 235        }
 236        result.field("connection_id", &self.connection_id).finish()
 237    }
 238}
 239
 240struct DbHandle(Arc<Database>);
 241
 242impl Deref for DbHandle {
 243    type Target = Database;
 244
 245    fn deref(&self) -> &Self::Target {
 246        self.0.as_ref()
 247    }
 248}
 249
/// The RPC server: owns the peer, the shared connection pool, and the table of message
/// handlers registered in `Server::new`.
pub struct Server {
    /// This server's id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Type-erased handlers keyed by the `TypeId` of the message type each accepts
    /// (populated by `add_handler`).
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`, and back to `false` by the test-only `reset`.
    teardown: watch::Sender<bool>,
}
 258
/// Lock guard over the shared `ConnectionPool`, as returned by `Session::connection_pool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc<()>` is `!Send`, so this marker makes the guard `!Send`: a future that keeps the
    // guard alive across an `.await` is itself `!Send`.
    _not_send: PhantomData<Rc<()>>,
}
 263
/// Serializable snapshot of the server's peer and connection pool.
// Field order below is the serialized field order.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized as whatever the guard dereferences to, via `serialize_deref`. The pool
    /// stays locked for as long as the snapshot holds this guard.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 270
 271pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 272where
 273    S: Serializer,
 274    T: Deref<Target = U>,
 275    U: Serialize,
 276{
 277    Serialize::serialize(value.deref(), serializer)
 278}
 279
 280impl Server {
 281    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 282        let mut server = Self {
 283            id: parking_lot::Mutex::new(id),
 284            peer: Peer::new(id.0 as u32),
 285            app_state: app_state.clone(),
 286            connection_pool: Default::default(),
 287            handlers: Default::default(),
 288            teardown: watch::channel(false).0,
 289        };
 290
 291        server
 292            .add_request_handler(ping)
 293            .add_request_handler(create_room)
 294            .add_request_handler(join_room)
 295            .add_request_handler(rejoin_room)
 296            .add_request_handler(leave_room)
 297            .add_request_handler(set_room_participant_role)
 298            .add_request_handler(call)
 299            .add_request_handler(cancel_call)
 300            .add_message_handler(decline_call)
 301            .add_request_handler(update_participant_location)
 302            .add_request_handler(share_project)
 303            .add_message_handler(unshare_project)
 304            .add_request_handler(join_project)
 305            .add_message_handler(leave_project)
 306            .add_request_handler(update_project)
 307            .add_request_handler(update_worktree)
 308            .add_request_handler(update_repository)
 309            .add_request_handler(remove_repository)
 310            .add_message_handler(start_language_server)
 311            .add_message_handler(update_language_server)
 312            .add_message_handler(update_diagnostic_summary)
 313            .add_message_handler(update_worktree_settings)
 314            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 315            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 316            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 317            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 318            .add_request_handler(forward_find_search_candidates_request)
 319            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 320            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 321            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 322            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 323            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 324            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 325            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 326            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 327            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 328            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 329            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 330            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 331            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 332            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 333            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 334            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 335            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 336            .add_request_handler(
 337                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 338            )
 339            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 340            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 341            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 342            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 343            .add_request_handler(
 344                forward_read_only_project_request::<proto::LanguageServerIdForName>,
 345            )
 346            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
 347            .add_request_handler(
 348                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 349            )
 350            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 351            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 352            .add_request_handler(
 353                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 354            )
 355            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 356            .add_request_handler(
 357                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 358            )
 359            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 360            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 361            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 362            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 363            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 364            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 365            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 366            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 367            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 368            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 369            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 370            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 371            .add_request_handler(
 372                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 373            )
 374            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 375            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 376            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 377            .add_request_handler(multi_lsp_query)
 378            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 379            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 380            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 381            .add_message_handler(create_buffer_for_peer)
 382            .add_request_handler(update_buffer)
 383            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 384            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 385            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 386            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 387            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 388            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 389            .add_message_handler(
 390                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 391            )
 392            .add_request_handler(get_users)
 393            .add_request_handler(fuzzy_search_users)
 394            .add_request_handler(request_contact)
 395            .add_request_handler(remove_contact)
 396            .add_request_handler(respond_to_contact_request)
 397            .add_message_handler(subscribe_to_channels)
 398            .add_request_handler(create_channel)
 399            .add_request_handler(delete_channel)
 400            .add_request_handler(invite_channel_member)
 401            .add_request_handler(remove_channel_member)
 402            .add_request_handler(set_channel_member_role)
 403            .add_request_handler(set_channel_visibility)
 404            .add_request_handler(rename_channel)
 405            .add_request_handler(join_channel_buffer)
 406            .add_request_handler(leave_channel_buffer)
 407            .add_message_handler(update_channel_buffer)
 408            .add_request_handler(rejoin_channel_buffers)
 409            .add_request_handler(get_channel_members)
 410            .add_request_handler(respond_to_channel_invite)
 411            .add_request_handler(join_channel)
 412            .add_request_handler(join_channel_chat)
 413            .add_message_handler(leave_channel_chat)
 414            .add_request_handler(send_channel_message)
 415            .add_request_handler(remove_channel_message)
 416            .add_request_handler(update_channel_message)
 417            .add_request_handler(get_channel_messages)
 418            .add_request_handler(get_channel_messages_by_id)
 419            .add_request_handler(get_notifications)
 420            .add_request_handler(mark_notification_as_read)
 421            .add_request_handler(move_channel)
 422            .add_request_handler(reorder_channel)
 423            .add_request_handler(follow)
 424            .add_message_handler(unfollow)
 425            .add_message_handler(update_followers)
 426            .add_request_handler(get_private_user_info)
 427            .add_request_handler(get_llm_api_token)
 428            .add_request_handler(accept_terms_of_service)
 429            .add_message_handler(acknowledge_channel_message)
 430            .add_message_handler(acknowledge_buffer_version)
 431            .add_request_handler(get_supermaven_api_key)
 432            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 433            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 434            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 435            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 436            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 437            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 438            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 439            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 440            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 441            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 442            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 443            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 444            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 445            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 446            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 447            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 448            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 449            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 450            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 451            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 452            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 453            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 454            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 455            .add_message_handler(update_context);
 456
 457        Arc::new(server)
 458    }
 459
    /// Schedules cleanup of resources left behind by earlier server instances.
    ///
    /// Returns `Ok(())` immediately. The work runs in a detached task, instrumented with the
    /// `"start server"` span, using the server id captured at call time. That task first
    /// waits for `CLEANUP_TIMEOUT`, then:
    /// 1. calls `delete_stale_channel_chat_participants`;
    /// 2. for each stale channel buffer, clears stale collaborators and sends the refreshed
    ///    collaborator list to the remaining connections;
    /// 3. for each stale room, clears stale participants, broadcasts the room/channel update,
    ///    sends `CallCanceled` for canceled calls, pushes contact updates, and deletes the
    ///    LiveKit room if no participants remain;
    /// 4. calls `delete_stale_channel_chat_participants` again, then
    ///    `clear_old_worktree_entries` and `delete_stale_servers`.
    ///
    /// Every step is best-effort. Errors are passed to `trace_err()` and do not stop the
    /// remaining steps.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The timer is created now, so the countdown starts at call time, not when the task
        // is first polled.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                // NOTE(review): judging by the name, this yields the rooms and channel buffers
                // still attributed to stale servers in this environment — confirm in `db`.
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            // Tell every remaining connection on this buffer who is left.
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used when refreshing the room fails: nothing to notify and
                        // no LiveKit room to delete.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // Only delete the LiveKit room once nobody is left in it.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scope the pool lock to this block so it is released before the
                        // `.await`s below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to all of their
                        // accepted contacts' connections.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // NOTE(review): this repeats the call made right after the timeout; confirm the
                // second pass is intentional before removing either.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 630
 631    pub fn teardown(&self) {
 632        self.peer.teardown();
 633        self.connection_pool.lock().reset();
 634        let _ = self.teardown.send(true);
 635    }
 636
 637    #[cfg(test)]
 638    pub fn reset(&self, id: ServerId) {
 639        self.teardown();
 640        *self.id.lock() = id;
 641        self.peer.reset(id.0 as u32);
 642        let _ = self.teardown.send(false);
 643    }
 644
 645    #[cfg(test)]
 646    pub fn id(&self) -> ServerId {
 647        *self.id.lock()
 648    }
 649
 650    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 651    where
 652        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 653        Fut: 'static + Send + Future<Output = Result<()>>,
 654        M: EnvelopedMessage,
 655    {
 656        let prev_handler = self.handlers.insert(
 657            TypeId::of::<M>(),
 658            Box::new(move |envelope, session| {
 659                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 660                let received_at = envelope.received_at;
 661                tracing::info!("message received");
 662                let start_time = Instant::now();
 663                let future = (handler)(*envelope, session);
 664                async move {
 665                    let result = future.await;
 666                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 667                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 668                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 669                    let payload_type = M::NAME;
 670
 671                    match result {
 672                        Err(error) => {
 673                            tracing::error!(
 674                                ?error,
 675                                total_duration_ms,
 676                                processing_duration_ms,
 677                                queue_duration_ms,
 678                                payload_type,
 679                                "error handling message"
 680                            )
 681                        }
 682                        Ok(()) => tracing::info!(
 683                            total_duration_ms,
 684                            processing_duration_ms,
 685                            queue_duration_ms,
 686                            "finished handling message"
 687                        ),
 688                    }
 689                }
 690                .boxed()
 691            }),
 692        );
 693        if prev_handler.is_some() {
 694            panic!("registered a handler for the same message twice");
 695        }
 696        self
 697    }
 698
 699    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 700    where
 701        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 702        Fut: 'static + Send + Future<Output = Result<()>>,
 703        M: EnvelopedMessage,
 704    {
 705        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 706        self
 707    }
 708
 709    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 710    where
 711        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 712        Fut: Send + Future<Output = Result<()>>,
 713        M: RequestMessage,
 714    {
 715        let handler = Arc::new(handler);
 716        self.add_handler(move |envelope, session| {
 717            let receipt = envelope.receipt();
 718            let handler = handler.clone();
 719            async move {
 720                let peer = session.peer.clone();
 721                let responded = Arc::new(AtomicBool::default());
 722                let response = Response {
 723                    peer: peer.clone(),
 724                    responded: responded.clone(),
 725                    receipt,
 726                };
 727                match (handler)(envelope.payload, response, session).await {
 728                    Ok(()) => {
 729                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 730                            Ok(())
 731                        } else {
 732                            Err(anyhow!("handler did not send a response"))?
 733                        }
 734                    }
 735                    Err(error) => {
 736                        let proto_err = match &error {
 737                            Error::Internal(err) => err.to_proto(),
 738                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 739                        };
 740                        peer.respond_with_error(receipt, proto_err)?;
 741                        Err(error)
 742                    }
 743                }
 744            }
 745        })
 746    }
 747
    /// Drives one client connection from handshake to sign-out.
    ///
    /// The returned future is instrumented with a "handle connection" span. The span carries the
    /// address, principal fields, user agent and country code. The future:
    /// 1. returns immediately if the server is already tearing down;
    /// 2. registers `connection` with the peer;
    /// 3. builds the connection's `Session` and sends the initial client update. The session
    ///    includes a Supermaven admin client when `supermaven_admin_api_key` is configured;
    /// 4. dispatches incoming messages to the registered handlers until the connection closes,
    ///    its I/O future completes, or the server tears down;
    /// 5. calls `connection_lost` when the connection closes or its I/O future completes. It
    ///    does not do this on teardown.
    ///
    /// `send_connection_id`, if provided, is handed to `send_initial_client_update`, which sends
    /// it the new `ConnectionId`. `connection_guard` is released once the initial client update
    /// has been sent, or earlier if the future returns first.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribe before the future starts so a teardown signalled later is observed below.
        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse to serve the connection once teardown has been signalled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }

            // `handle_io` is polled in the loop below and ending it ends the connection;
            // `incoming_rx` yields the connection's incoming messages.
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // NOTE(review): this HTTP client is built for every connection even when no
            // Supermaven key is configured (in this function it is only used for
            // `supermaven_client`). Failing to build it drops the connection. Consider
            // constructing it lazily.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                // NOTE(review): this early return skips `connection_lost`. The peer connection
                // is not cleaned up here. Neither is the pool entry, if
                // `send_initial_client_update` got far enough to add it. Confirm this is
                // handled elsewhere.
                return;
            }
            // Setup is done; release the connection guard. In production it is acquired by
            // `handle_websocket_request`.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers may be in flight for this connection. A permit is acquired
            // before reading each message and released when that message's handler finishes,
            // so a saturated connection stops reading new messages.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls branches in the order written: teardown, I/O
                // completion, progress on queued foreground handlers, and only then a new message.
                futures::select_biased! {
                    // On teardown, exit without running `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                multi_lsp_query_request=field::Empty,
                            );
                            principal.update_span(&span);
                            // The span is entered only while the handler future is created; the
                            // future itself is instrumented with the span below.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is held until the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached; foreground ones are polled
                                // by this loop through `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that are still pending.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 905
 906    async fn send_initial_client_update(
 907        &self,
 908        connection_id: ConnectionId,
 909        zed_version: ZedVersion,
 910        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 911        session: &Session,
 912    ) -> Result<()> {
 913        self.peer.send(
 914            connection_id,
 915            proto::Hello {
 916                peer_id: Some(connection_id.into()),
 917            },
 918        )?;
 919        tracing::info!("sent hello message");
 920        if let Some(send_connection_id) = send_connection_id.take() {
 921            let _ = send_connection_id.send(connection_id);
 922        }
 923
 924        match &session.principal {
 925            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 926                if !user.connected_once {
 927                    self.peer.send(connection_id, proto::ShowContacts {})?;
 928                    self.app_state
 929                        .db
 930                        .set_user_connected_once(user.id, true)
 931                        .await?;
 932                }
 933
 934                update_user_plan(session).await?;
 935
 936                let contacts = self.app_state.db.get_contacts(user.id).await?;
 937
 938                {
 939                    let mut pool = self.connection_pool.lock();
 940                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 941                    self.peer.send(
 942                        connection_id,
 943                        build_initial_contacts_update(contacts, &pool),
 944                    )?;
 945                }
 946
 947                if should_auto_subscribe_to_channels(zed_version) {
 948                    subscribe_user_to_channels(user.id, session).await?;
 949                }
 950
 951                if let Some(incoming_call) =
 952                    self.app_state.db.incoming_call_for_user(user.id).await?
 953                {
 954                    self.peer.send(connection_id, incoming_call)?;
 955                }
 956
 957                update_user_contacts(user.id, session).await?;
 958            }
 959        }
 960
 961        Ok(())
 962    }
 963
 964    pub async fn invite_code_redeemed(
 965        self: &Arc<Self>,
 966        inviter_id: UserId,
 967        invitee_id: UserId,
 968    ) -> Result<()> {
 969        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 970            if let Some(code) = &user.invite_code {
 971                let pool = self.connection_pool.lock();
 972                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 973                for connection_id in pool.user_connection_ids(inviter_id) {
 974                    self.peer.send(
 975                        connection_id,
 976                        proto::UpdateContacts {
 977                            contacts: vec![invitee_contact.clone()],
 978                            ..Default::default()
 979                        },
 980                    )?;
 981                    self.peer.send(
 982                        connection_id,
 983                        proto::UpdateInviteInfo {
 984                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 985                            count: user.invite_count as u32,
 986                        },
 987                    )?;
 988                }
 989            }
 990        }
 991        Ok(())
 992    }
 993
 994    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 995        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 996            if let Some(invite_code) = &user.invite_code {
 997                let pool = self.connection_pool.lock();
 998                for connection_id in pool.user_connection_ids(user_id) {
 999                    self.peer.send(
1000                        connection_id,
1001                        proto::UpdateInviteInfo {
1002                            url: format!(
1003                                "{}{}",
1004                                self.app_state.config.invite_link_prefix, invite_code
1005                            ),
1006                            count: user.invite_count as u32,
1007                        },
1008                    )?;
1009                }
1010            }
1011        }
1012        Ok(())
1013    }
1014
1015    pub async fn update_plan_for_user(
1016        self: &Arc<Self>,
1017        user_id: UserId,
1018        update_user_plan: proto::UpdateUserPlan,
1019    ) -> Result<()> {
1020        let pool = self.connection_pool.lock();
1021        for connection_id in pool.user_connection_ids(user_id) {
1022            self.peer
1023                .send(connection_id, update_user_plan.clone())
1024                .trace_err();
1025        }
1026
1027        Ok(())
1028    }
1029
1030    /// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan`
1031    /// message on the Collab server.
1032    ///
1033    /// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint.
1034    pub async fn update_plan_for_user_legacy(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1035        let user = self
1036            .app_state
1037            .db
1038            .get_user_by_id(user_id)
1039            .await?
1040            .context("user not found")?;
1041
1042        let update_user_plan = make_update_user_plan_message(
1043            &user,
1044            user.admin,
1045            &self.app_state.db,
1046            self.app_state.llm_db.clone(),
1047        )
1048        .await?;
1049
1050        self.update_plan_for_user(user_id, update_user_plan).await
1051    }
1052
1053    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
1054        let pool = self.connection_pool.lock();
1055        for connection_id in pool.user_connection_ids(user_id) {
1056            self.peer
1057                .send(connection_id, proto::RefreshLlmToken {})
1058                .trace_err();
1059        }
1060    }
1061
1062    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1063        ServerSnapshot {
1064            connection_pool: ConnectionPoolGuard {
1065                guard: self.connection_pool.lock(),
1066                _not_send: PhantomData,
1067            },
1068            peer: &self.peer,
1069        }
1070    }
1071}
1072
1073impl Deref for ConnectionPoolGuard<'_> {
1074    type Target = ConnectionPool;
1075
1076    fn deref(&self) -> &Self::Target {
1077        &self.guard
1078    }
1079}
1080
1081impl DerefMut for ConnectionPoolGuard<'_> {
1082    fn deref_mut(&mut self) -> &mut Self::Target {
1083        &mut self.guard
1084    }
1085}
1086
/// In test builds, checks the pool's invariants every time a guard is dropped, so any mutation
/// made through the guard is validated before the lock is released.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // Compiled only under `cfg(test)`; in other builds this drop does no extra work.
        #[cfg(test)]
        self.check_invariants();
    }
}
1093
1094fn broadcast<F>(
1095    sender_id: Option<ConnectionId>,
1096    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1097    mut f: F,
1098) where
1099    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1100{
1101    for receiver_id in receiver_ids {
1102        if Some(receiver_id) != sender_id {
1103            if let Err(error) = f(receiver_id) {
1104                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1105            }
1106        }
1107    }
1108}
1109
1110pub struct ProtocolVersion(u32);
1111
1112impl Header for ProtocolVersion {
1113    fn name() -> &'static HeaderName {
1114        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1115        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1116    }
1117
1118    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1119    where
1120        Self: Sized,
1121        I: Iterator<Item = &'i axum::http::HeaderValue>,
1122    {
1123        let version = values
1124            .next()
1125            .ok_or_else(axum::headers::Error::invalid)?
1126            .to_str()
1127            .map_err(|_| axum::headers::Error::invalid())?
1128            .parse()
1129            .map_err(|_| axum::headers::Error::invalid())?;
1130        Ok(Self(version))
1131    }
1132
1133    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1134        values.extend([self.0.to_string().parse().unwrap()]);
1135    }
1136}
1137
1138pub struct AppVersionHeader(SemanticVersion);
1139impl Header for AppVersionHeader {
1140    fn name() -> &'static HeaderName {
1141        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1142        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1143    }
1144
1145    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1146    where
1147        Self: Sized,
1148        I: Iterator<Item = &'i axum::http::HeaderValue>,
1149    {
1150        let version = values
1151            .next()
1152            .ok_or_else(axum::headers::Error::invalid)?
1153            .to_str()
1154            .map_err(|_| axum::headers::Error::invalid())?
1155            .parse()
1156            .map_err(|_| axum::headers::Error::invalid())?;
1157        Ok(Self(version))
1158    }
1159
1160    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1161        values.extend([self.0.to_string().parse().unwrap()]);
1162    }
1163}
1164
1165pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1166    Router::new()
1167        .route("/rpc", get(handle_websocket_request))
1168        .layer(
1169            ServiceBuilder::new()
1170                .layer(Extension(server.app_state.clone()))
1171                .layer(middleware::from_fn(auth::validate_header)),
1172        )
1173        .route("/metrics", get(handle_metrics))
1174        .layer(Extension(server))
1175}
1176
1177pub async fn handle_websocket_request(
1178    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1179    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1180    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1181    Extension(server): Extension<Arc<Server>>,
1182    Extension(principal): Extension<Principal>,
1183    user_agent: Option<TypedHeader<UserAgent>>,
1184    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1185    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1186    ws: WebSocketUpgrade,
1187) -> axum::response::Response {
1188    if protocol_version != rpc::PROTOCOL_VERSION {
1189        return (
1190            StatusCode::UPGRADE_REQUIRED,
1191            "client must be upgraded".to_string(),
1192        )
1193            .into_response();
1194    }
1195
1196    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1197        return (
1198            StatusCode::UPGRADE_REQUIRED,
1199            "no version header found".to_string(),
1200        )
1201            .into_response();
1202    };
1203
1204    if !version.can_collaborate() {
1205        return (
1206            StatusCode::UPGRADE_REQUIRED,
1207            "client must be upgraded".to_string(),
1208        )
1209            .into_response();
1210    }
1211
1212    let socket_address = socket_address.to_string();
1213
1214    // Acquire connection guard before WebSocket upgrade
1215    let connection_guard = match ConnectionGuard::try_acquire() {
1216        Ok(guard) => guard,
1217        Err(()) => {
1218            return (
1219                StatusCode::SERVICE_UNAVAILABLE,
1220                "Too many concurrent connections",
1221            )
1222                .into_response();
1223        }
1224    };
1225
1226    ws.on_upgrade(move |socket| {
1227        let socket = socket
1228            .map_ok(to_tungstenite_message)
1229            .err_into()
1230            .with(|message| async move { to_axum_message(message) });
1231        let connection = Connection::new(Box::pin(socket));
1232        async move {
1233            server
1234                .handle_connection(
1235                    connection,
1236                    socket_address,
1237                    principal,
1238                    version,
1239                    user_agent.map(|header| header.to_string()),
1240                    country_code_header.map(|header| header.to_string()),
1241                    system_id_header.map(|header| header.to_string()),
1242                    None,
1243                    Executor::Production,
1244                    Some(connection_guard),
1245                )
1246                .await;
1247        }
1248    })
1249}
1250
1251pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1252    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1253    let connections_metric = CONNECTIONS_METRIC
1254        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1255
1256    let connections = server
1257        .connection_pool
1258        .lock()
1259        .connections()
1260        .filter(|connection| !connection.admin)
1261        .count();
1262    connections_metric.set(connections as _);
1263
1264    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1265    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1266        register_int_gauge!(
1267            "shared_projects",
1268            "number of open projects with one or more guests"
1269        )
1270        .unwrap()
1271    });
1272
1273    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1274    shared_projects_metric.set(shared_projects as _);
1275
1276    let encoder = prometheus::TextEncoder::new();
1277    let metric_families = prometheus::gather();
1278    let encoded_metrics = encoder
1279        .encode_to_string(&metric_families)
1280        .map_err(|err| anyhow!("{err}"))?;
1281    Ok(encoded_metrics)
1282}
1283
1284#[instrument(err, skip(executor))]
1285async fn connection_lost(
1286    session: Session,
1287    mut teardown: watch::Receiver<bool>,
1288    executor: Executor,
1289) -> Result<()> {
1290    session.peer.disconnect(session.connection_id);
1291    session
1292        .connection_pool()
1293        .await
1294        .remove_connection(session.connection_id)?;
1295
1296    session
1297        .db()
1298        .await
1299        .connection_lost(session.connection_id)
1300        .await
1301        .trace_err();
1302
1303    futures::select_biased! {
1304        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1305
1306            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1307            leave_room_for_session(&session, session.connection_id).await.trace_err();
1308            leave_channel_buffers_for_session(&session)
1309                .await
1310                .trace_err();
1311
1312            if !session
1313                .connection_pool()
1314                .await
1315                .is_user_online(session.user_id())
1316            {
1317                let db = session.db().await;
1318                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1319                    room_updated(&room, &session.peer);
1320                }
1321            }
1322
1323            update_user_contacts(session.user_id(), &session).await?;
1324        },
1325        _ = teardown.changed().fuse() => {}
1326    }
1327
1328    Ok(())
1329}
1330
1331/// Acknowledges a ping from a client, used to keep the connection alive.
1332async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1333    response.send(proto::Ack {})?;
1334    Ok(())
1335}
1336
1337/// Creates a new room for calling (outside of channels)
1338async fn create_room(
1339    _request: proto::CreateRoom,
1340    response: Response<proto::CreateRoom>,
1341    session: Session,
1342) -> Result<()> {
1343    let livekit_room = nanoid::nanoid!(30);
1344
1345    let live_kit_connection_info = util::maybe!(async {
1346        let live_kit = session.app_state.livekit_client.as_ref();
1347        let live_kit = live_kit?;
1348        let user_id = session.user_id().to_string();
1349
1350        let token = live_kit
1351            .room_token(&livekit_room, &user_id.to_string())
1352            .trace_err()?;
1353
1354        Some(proto::LiveKitConnectionInfo {
1355            server_url: live_kit.url().into(),
1356            token,
1357            can_publish: true,
1358        })
1359    })
1360    .await;
1361
1362    let room = session
1363        .db()
1364        .await
1365        .create_room(session.user_id(), session.connection_id, &livekit_room)
1366        .await?;
1367
1368    response.send(proto::CreateRoomResponse {
1369        room: Some(room.clone()),
1370        live_kit_connection_info,
1371    })?;
1372
1373    update_user_contacts(session.user_id(), &session).await?;
1374    Ok(())
1375}
1376
1377/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1378async fn join_room(
1379    request: proto::JoinRoom,
1380    response: Response<proto::JoinRoom>,
1381    session: Session,
1382) -> Result<()> {
1383    let room_id = RoomId::from_proto(request.id);
1384
1385    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1386
1387    if let Some(channel_id) = channel_id {
1388        return join_channel_internal(channel_id, Box::new(response), session).await;
1389    }
1390
1391    let joined_room = {
1392        let room = session
1393            .db()
1394            .await
1395            .join_room(room_id, session.user_id(), session.connection_id)
1396            .await?;
1397        room_updated(&room.room, &session.peer);
1398        room.into_inner()
1399    };
1400
1401    for connection_id in session
1402        .connection_pool()
1403        .await
1404        .user_connection_ids(session.user_id())
1405    {
1406        session
1407            .peer
1408            .send(
1409                connection_id,
1410                proto::CallCanceled {
1411                    room_id: room_id.to_proto(),
1412                },
1413            )
1414            .trace_err();
1415    }
1416
1417    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1418    {
1419        live_kit
1420            .room_token(
1421                &joined_room.room.livekit_room,
1422                &session.user_id().to_string(),
1423            )
1424            .trace_err()
1425            .map(|token| proto::LiveKitConnectionInfo {
1426                server_url: live_kit.url().into(),
1427                token,
1428                can_publish: true,
1429            })
1430    } else {
1431        None
1432    };
1433
1434    response.send(proto::JoinRoomResponse {
1435        room: Some(joined_room.room),
1436        channel_id: None,
1437        live_kit_connection_info,
1438    })?;
1439
1440    update_user_contacts(session.user_id(), &session).await?;
1441    Ok(())
1442}
1443
1444/// Rejoin room is used to reconnect to a room after connection errors.
1445async fn rejoin_room(
1446    request: proto::RejoinRoom,
1447    response: Response<proto::RejoinRoom>,
1448    session: Session,
1449) -> Result<()> {
1450    let room;
1451    let channel;
1452    {
1453        let mut rejoined_room = session
1454            .db()
1455            .await
1456            .rejoin_room(request, session.user_id(), session.connection_id)
1457            .await?;
1458
1459        response.send(proto::RejoinRoomResponse {
1460            room: Some(rejoined_room.room.clone()),
1461            reshared_projects: rejoined_room
1462                .reshared_projects
1463                .iter()
1464                .map(|project| proto::ResharedProject {
1465                    id: project.id.to_proto(),
1466                    collaborators: project
1467                        .collaborators
1468                        .iter()
1469                        .map(|collaborator| collaborator.to_proto())
1470                        .collect(),
1471                })
1472                .collect(),
1473            rejoined_projects: rejoined_room
1474                .rejoined_projects
1475                .iter()
1476                .map(|rejoined_project| rejoined_project.to_proto())
1477                .collect(),
1478        })?;
1479        room_updated(&rejoined_room.room, &session.peer);
1480
1481        for project in &rejoined_room.reshared_projects {
1482            for collaborator in &project.collaborators {
1483                session
1484                    .peer
1485                    .send(
1486                        collaborator.connection_id,
1487                        proto::UpdateProjectCollaborator {
1488                            project_id: project.id.to_proto(),
1489                            old_peer_id: Some(project.old_connection_id.into()),
1490                            new_peer_id: Some(session.connection_id.into()),
1491                        },
1492                    )
1493                    .trace_err();
1494            }
1495
1496            broadcast(
1497                Some(session.connection_id),
1498                project
1499                    .collaborators
1500                    .iter()
1501                    .map(|collaborator| collaborator.connection_id),
1502                |connection_id| {
1503                    session.peer.forward_send(
1504                        session.connection_id,
1505                        connection_id,
1506                        proto::UpdateProject {
1507                            project_id: project.id.to_proto(),
1508                            worktrees: project.worktrees.clone(),
1509                        },
1510                    )
1511                },
1512            );
1513        }
1514
1515        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1516
1517        let rejoined_room = rejoined_room.into_inner();
1518
1519        room = rejoined_room.room;
1520        channel = rejoined_room.channel;
1521    }
1522
1523    if let Some(channel) = channel {
1524        channel_updated(
1525            &channel,
1526            &room,
1527            &session.peer,
1528            &*session.connection_pool().await,
1529        );
1530    }
1531
1532    update_user_contacts(session.user_id(), &session).await?;
1533    Ok(())
1534}
1535
1536fn notify_rejoined_projects(
1537    rejoined_projects: &mut Vec<RejoinedProject>,
1538    session: &Session,
1539) -> Result<()> {
1540    for project in rejoined_projects.iter() {
1541        for collaborator in &project.collaborators {
1542            session
1543                .peer
1544                .send(
1545                    collaborator.connection_id,
1546                    proto::UpdateProjectCollaborator {
1547                        project_id: project.id.to_proto(),
1548                        old_peer_id: Some(project.old_connection_id.into()),
1549                        new_peer_id: Some(session.connection_id.into()),
1550                    },
1551                )
1552                .trace_err();
1553        }
1554    }
1555
1556    for project in rejoined_projects {
1557        for worktree in mem::take(&mut project.worktrees) {
1558            // Stream this worktree's entries.
1559            let message = proto::UpdateWorktree {
1560                project_id: project.id.to_proto(),
1561                worktree_id: worktree.id,
1562                abs_path: worktree.abs_path.clone(),
1563                root_name: worktree.root_name,
1564                updated_entries: worktree.updated_entries,
1565                removed_entries: worktree.removed_entries,
1566                scan_id: worktree.scan_id,
1567                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1568                updated_repositories: worktree.updated_repositories,
1569                removed_repositories: worktree.removed_repositories,
1570            };
1571            for update in proto::split_worktree_update(message) {
1572                session.peer.send(session.connection_id, update)?;
1573            }
1574
1575            // Stream this worktree's diagnostics.
1576            for summary in worktree.diagnostic_summaries {
1577                session.peer.send(
1578                    session.connection_id,
1579                    proto::UpdateDiagnosticSummary {
1580                        project_id: project.id.to_proto(),
1581                        worktree_id: worktree.id,
1582                        summary: Some(summary),
1583                    },
1584                )?;
1585            }
1586
1587            for settings_file in worktree.settings_files {
1588                session.peer.send(
1589                    session.connection_id,
1590                    proto::UpdateWorktreeSettings {
1591                        project_id: project.id.to_proto(),
1592                        worktree_id: worktree.id,
1593                        path: settings_file.path,
1594                        content: Some(settings_file.content),
1595                        kind: Some(settings_file.kind.to_proto().into()),
1596                    },
1597                )?;
1598            }
1599        }
1600
1601        for repository in mem::take(&mut project.updated_repositories) {
1602            for update in split_repository_update(repository) {
1603                session.peer.send(session.connection_id, update)?;
1604            }
1605        }
1606
1607        for id in mem::take(&mut project.removed_repositories) {
1608            session.peer.send(
1609                session.connection_id,
1610                proto::RemoveRepository {
1611                    project_id: project.id.to_proto(),
1612                    id,
1613                },
1614            )?;
1615        }
1616    }
1617
1618    Ok(())
1619}
1620
1621/// leave room disconnects from the room.
1622async fn leave_room(
1623    _: proto::LeaveRoom,
1624    response: Response<proto::LeaveRoom>,
1625    session: Session,
1626) -> Result<()> {
1627    leave_room_for_session(&session, session.connection_id).await?;
1628    response.send(proto::Ack {})?;
1629    Ok(())
1630}
1631
1632/// Updates the permissions of someone else in the room.
1633async fn set_room_participant_role(
1634    request: proto::SetRoomParticipantRole,
1635    response: Response<proto::SetRoomParticipantRole>,
1636    session: Session,
1637) -> Result<()> {
1638    let user_id = UserId::from_proto(request.user_id);
1639    let role = ChannelRole::from(request.role());
1640
1641    let (livekit_room, can_publish) = {
1642        let room = session
1643            .db()
1644            .await
1645            .set_room_participant_role(
1646                session.user_id(),
1647                RoomId::from_proto(request.room_id),
1648                user_id,
1649                role,
1650            )
1651            .await?;
1652
1653        let livekit_room = room.livekit_room.clone();
1654        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1655        room_updated(&room, &session.peer);
1656        (livekit_room, can_publish)
1657    };
1658
1659    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1660        live_kit
1661            .update_participant(
1662                livekit_room.clone(),
1663                request.user_id.to_string(),
1664                livekit_api::proto::ParticipantPermission {
1665                    can_subscribe: true,
1666                    can_publish,
1667                    can_publish_data: can_publish,
1668                    hidden: false,
1669                    recorder: false,
1670                },
1671            )
1672            .await
1673            .trace_err();
1674    }
1675
1676    response.send(proto::Ack {})?;
1677    Ok(())
1678}
1679
1680/// Call someone else into the current room
1681async fn call(
1682    request: proto::Call,
1683    response: Response<proto::Call>,
1684    session: Session,
1685) -> Result<()> {
1686    let room_id = RoomId::from_proto(request.room_id);
1687    let calling_user_id = session.user_id();
1688    let calling_connection_id = session.connection_id;
1689    let called_user_id = UserId::from_proto(request.called_user_id);
1690    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1691    if !session
1692        .db()
1693        .await
1694        .has_contact(calling_user_id, called_user_id)
1695        .await?
1696    {
1697        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1698    }
1699
1700    let incoming_call = {
1701        let (room, incoming_call) = &mut *session
1702            .db()
1703            .await
1704            .call(
1705                room_id,
1706                calling_user_id,
1707                calling_connection_id,
1708                called_user_id,
1709                initial_project_id,
1710            )
1711            .await?;
1712        room_updated(room, &session.peer);
1713        mem::take(incoming_call)
1714    };
1715    update_user_contacts(called_user_id, &session).await?;
1716
1717    let mut calls = session
1718        .connection_pool()
1719        .await
1720        .user_connection_ids(called_user_id)
1721        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1722        .collect::<FuturesUnordered<_>>();
1723
1724    while let Some(call_response) = calls.next().await {
1725        match call_response.as_ref() {
1726            Ok(_) => {
1727                response.send(proto::Ack {})?;
1728                return Ok(());
1729            }
1730            Err(_) => {
1731                call_response.trace_err();
1732            }
1733        }
1734    }
1735
1736    {
1737        let room = session
1738            .db()
1739            .await
1740            .call_failed(room_id, called_user_id)
1741            .await?;
1742        room_updated(&room, &session.peer);
1743    }
1744    update_user_contacts(called_user_id, &session).await?;
1745
1746    Err(anyhow!("failed to ring user"))?
1747}
1748
1749/// Cancel an outgoing call.
1750async fn cancel_call(
1751    request: proto::CancelCall,
1752    response: Response<proto::CancelCall>,
1753    session: Session,
1754) -> Result<()> {
1755    let called_user_id = UserId::from_proto(request.called_user_id);
1756    let room_id = RoomId::from_proto(request.room_id);
1757    {
1758        let room = session
1759            .db()
1760            .await
1761            .cancel_call(room_id, session.connection_id, called_user_id)
1762            .await?;
1763        room_updated(&room, &session.peer);
1764    }
1765
1766    for connection_id in session
1767        .connection_pool()
1768        .await
1769        .user_connection_ids(called_user_id)
1770    {
1771        session
1772            .peer
1773            .send(
1774                connection_id,
1775                proto::CallCanceled {
1776                    room_id: room_id.to_proto(),
1777                },
1778            )
1779            .trace_err();
1780    }
1781    response.send(proto::Ack {})?;
1782
1783    update_user_contacts(called_user_id, &session).await?;
1784    Ok(())
1785}
1786
1787/// Decline an incoming call.
1788async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1789    let room_id = RoomId::from_proto(message.room_id);
1790    {
1791        let room = session
1792            .db()
1793            .await
1794            .decline_call(Some(room_id), session.user_id())
1795            .await?
1796            .context("declining call")?;
1797        room_updated(&room, &session.peer);
1798    }
1799
1800    for connection_id in session
1801        .connection_pool()
1802        .await
1803        .user_connection_ids(session.user_id())
1804    {
1805        session
1806            .peer
1807            .send(
1808                connection_id,
1809                proto::CallCanceled {
1810                    room_id: room_id.to_proto(),
1811                },
1812            )
1813            .trace_err();
1814    }
1815    update_user_contacts(session.user_id(), &session).await?;
1816    Ok(())
1817}
1818
1819/// Updates other participants in the room with your current location.
1820async fn update_participant_location(
1821    request: proto::UpdateParticipantLocation,
1822    response: Response<proto::UpdateParticipantLocation>,
1823    session: Session,
1824) -> Result<()> {
1825    let room_id = RoomId::from_proto(request.room_id);
1826    let location = request.location.context("invalid location")?;
1827
1828    let db = session.db().await;
1829    let room = db
1830        .update_room_participant_location(room_id, session.connection_id, location)
1831        .await?;
1832
1833    room_updated(&room, &session.peer);
1834    response.send(proto::Ack {})?;
1835    Ok(())
1836}
1837
1838/// Share a project into the room.
1839async fn share_project(
1840    request: proto::ShareProject,
1841    response: Response<proto::ShareProject>,
1842    session: Session,
1843) -> Result<()> {
1844    let (project_id, room) = &*session
1845        .db()
1846        .await
1847        .share_project(
1848            RoomId::from_proto(request.room_id),
1849            session.connection_id,
1850            &request.worktrees,
1851            request.is_ssh_project,
1852        )
1853        .await?;
1854    response.send(proto::ShareProjectResponse {
1855        project_id: project_id.to_proto(),
1856    })?;
1857    room_updated(room, &session.peer);
1858
1859    Ok(())
1860}
1861
1862/// Unshare a project from the room.
1863async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1864    let project_id = ProjectId::from_proto(message.project_id);
1865    unshare_project_internal(project_id, session.connection_id, &session).await
1866}
1867
1868async fn unshare_project_internal(
1869    project_id: ProjectId,
1870    connection_id: ConnectionId,
1871    session: &Session,
1872) -> Result<()> {
1873    let delete = {
1874        let room_guard = session
1875            .db()
1876            .await
1877            .unshare_project(project_id, connection_id)
1878            .await?;
1879
1880        let (delete, room, guest_connection_ids) = &*room_guard;
1881
1882        let message = proto::UnshareProject {
1883            project_id: project_id.to_proto(),
1884        };
1885
1886        broadcast(
1887            Some(connection_id),
1888            guest_connection_ids.iter().copied(),
1889            |conn_id| session.peer.send(conn_id, message.clone()),
1890        );
1891        if let Some(room) = room {
1892            room_updated(room, &session.peer);
1893        }
1894
1895        *delete
1896    };
1897
1898    if delete {
1899        let db = session.db().await;
1900        db.delete_project(project_id).await?;
1901    }
1902
1903    Ok(())
1904}
1905
1906/// Join someone elses shared project.
1907async fn join_project(
1908    request: proto::JoinProject,
1909    response: Response<proto::JoinProject>,
1910    session: Session,
1911) -> Result<()> {
1912    let project_id = ProjectId::from_proto(request.project_id);
1913
1914    tracing::info!(%project_id, "join project");
1915
1916    let db = session.db().await;
1917    let (project, replica_id) = &mut *db
1918        .join_project(
1919            project_id,
1920            session.connection_id,
1921            session.user_id(),
1922            request.committer_name.clone(),
1923            request.committer_email.clone(),
1924        )
1925        .await?;
1926    drop(db);
1927    tracing::info!(%project_id, "join remote project");
1928    let collaborators = project
1929        .collaborators
1930        .iter()
1931        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1932        .map(|collaborator| collaborator.to_proto())
1933        .collect::<Vec<_>>();
1934    let project_id = project.id;
1935    let guest_user_id = session.user_id();
1936
1937    let worktrees = project
1938        .worktrees
1939        .iter()
1940        .map(|(id, worktree)| proto::WorktreeMetadata {
1941            id: *id,
1942            root_name: worktree.root_name.clone(),
1943            visible: worktree.visible,
1944            abs_path: worktree.abs_path.clone(),
1945        })
1946        .collect::<Vec<_>>();
1947
1948    let add_project_collaborator = proto::AddProjectCollaborator {
1949        project_id: project_id.to_proto(),
1950        collaborator: Some(proto::Collaborator {
1951            peer_id: Some(session.connection_id.into()),
1952            replica_id: replica_id.0 as u32,
1953            user_id: guest_user_id.to_proto(),
1954            is_host: false,
1955            committer_name: request.committer_name.clone(),
1956            committer_email: request.committer_email.clone(),
1957        }),
1958    };
1959
1960    for collaborator in &collaborators {
1961        session
1962            .peer
1963            .send(
1964                collaborator.peer_id.unwrap().into(),
1965                add_project_collaborator.clone(),
1966            )
1967            .trace_err();
1968    }
1969
1970    // First, we send the metadata associated with each worktree.
1971    response.send(proto::JoinProjectResponse {
1972        project_id: project.id.0 as u64,
1973        worktrees: worktrees.clone(),
1974        replica_id: replica_id.0 as u32,
1975        collaborators: collaborators.clone(),
1976        language_servers: project.language_servers.clone(),
1977        role: project.role.into(),
1978    })?;
1979
1980    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1981        // Stream this worktree's entries.
1982        let message = proto::UpdateWorktree {
1983            project_id: project_id.to_proto(),
1984            worktree_id,
1985            abs_path: worktree.abs_path.clone(),
1986            root_name: worktree.root_name,
1987            updated_entries: worktree.entries,
1988            removed_entries: Default::default(),
1989            scan_id: worktree.scan_id,
1990            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1991            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1992            removed_repositories: Default::default(),
1993        };
1994        for update in proto::split_worktree_update(message) {
1995            session.peer.send(session.connection_id, update.clone())?;
1996        }
1997
1998        // Stream this worktree's diagnostics.
1999        for summary in worktree.diagnostic_summaries {
2000            session.peer.send(
2001                session.connection_id,
2002                proto::UpdateDiagnosticSummary {
2003                    project_id: project_id.to_proto(),
2004                    worktree_id: worktree.id,
2005                    summary: Some(summary),
2006                },
2007            )?;
2008        }
2009
2010        for settings_file in worktree.settings_files {
2011            session.peer.send(
2012                session.connection_id,
2013                proto::UpdateWorktreeSettings {
2014                    project_id: project_id.to_proto(),
2015                    worktree_id: worktree.id,
2016                    path: settings_file.path,
2017                    content: Some(settings_file.content),
2018                    kind: Some(settings_file.kind.to_proto() as i32),
2019                },
2020            )?;
2021        }
2022    }
2023
2024    for repository in mem::take(&mut project.repositories) {
2025        for update in split_repository_update(repository) {
2026            session.peer.send(session.connection_id, update)?;
2027        }
2028    }
2029
2030    for language_server in &project.language_servers {
2031        session.peer.send(
2032            session.connection_id,
2033            proto::UpdateLanguageServer {
2034                project_id: project_id.to_proto(),
2035                server_name: Some(language_server.name.clone()),
2036                language_server_id: language_server.id,
2037                variant: Some(
2038                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2039                        proto::LspDiskBasedDiagnosticsUpdated {},
2040                    ),
2041                ),
2042            },
2043        )?;
2044    }
2045
2046    Ok(())
2047}
2048
2049/// Leave someone elses shared project.
2050async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
2051    let sender_id = session.connection_id;
2052    let project_id = ProjectId::from_proto(request.project_id);
2053    let db = session.db().await;
2054
2055    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2056    tracing::info!(
2057        %project_id,
2058        "leave project"
2059    );
2060
2061    project_left(project, &session);
2062    if let Some(room) = room {
2063        room_updated(room, &session.peer);
2064    }
2065
2066    Ok(())
2067}
2068
2069/// Updates other participants with changes to the project
2070async fn update_project(
2071    request: proto::UpdateProject,
2072    response: Response<proto::UpdateProject>,
2073    session: Session,
2074) -> Result<()> {
2075    let project_id = ProjectId::from_proto(request.project_id);
2076    let (room, guest_connection_ids) = &*session
2077        .db()
2078        .await
2079        .update_project(project_id, session.connection_id, &request.worktrees)
2080        .await?;
2081    broadcast(
2082        Some(session.connection_id),
2083        guest_connection_ids.iter().copied(),
2084        |connection_id| {
2085            session
2086                .peer
2087                .forward_send(session.connection_id, connection_id, request.clone())
2088        },
2089    );
2090    if let Some(room) = room {
2091        room_updated(room, &session.peer);
2092    }
2093    response.send(proto::Ack {})?;
2094
2095    Ok(())
2096}
2097
2098/// Updates other participants with changes to the worktree
2099async fn update_worktree(
2100    request: proto::UpdateWorktree,
2101    response: Response<proto::UpdateWorktree>,
2102    session: Session,
2103) -> Result<()> {
2104    let guest_connection_ids = session
2105        .db()
2106        .await
2107        .update_worktree(&request, session.connection_id)
2108        .await?;
2109
2110    broadcast(
2111        Some(session.connection_id),
2112        guest_connection_ids.iter().copied(),
2113        |connection_id| {
2114            session
2115                .peer
2116                .forward_send(session.connection_id, connection_id, request.clone())
2117        },
2118    );
2119    response.send(proto::Ack {})?;
2120    Ok(())
2121}
2122
2123async fn update_repository(
2124    request: proto::UpdateRepository,
2125    response: Response<proto::UpdateRepository>,
2126    session: Session,
2127) -> Result<()> {
2128    let guest_connection_ids = session
2129        .db()
2130        .await
2131        .update_repository(&request, session.connection_id)
2132        .await?;
2133
2134    broadcast(
2135        Some(session.connection_id),
2136        guest_connection_ids.iter().copied(),
2137        |connection_id| {
2138            session
2139                .peer
2140                .forward_send(session.connection_id, connection_id, request.clone())
2141        },
2142    );
2143    response.send(proto::Ack {})?;
2144    Ok(())
2145}
2146
2147async fn remove_repository(
2148    request: proto::RemoveRepository,
2149    response: Response<proto::RemoveRepository>,
2150    session: Session,
2151) -> Result<()> {
2152    let guest_connection_ids = session
2153        .db()
2154        .await
2155        .remove_repository(&request, session.connection_id)
2156        .await?;
2157
2158    broadcast(
2159        Some(session.connection_id),
2160        guest_connection_ids.iter().copied(),
2161        |connection_id| {
2162            session
2163                .peer
2164                .forward_send(session.connection_id, connection_id, request.clone())
2165        },
2166    );
2167    response.send(proto::Ack {})?;
2168    Ok(())
2169}
2170
2171/// Updates other participants with changes to the diagnostics
2172async fn update_diagnostic_summary(
2173    message: proto::UpdateDiagnosticSummary,
2174    session: Session,
2175) -> Result<()> {
2176    let guest_connection_ids = session
2177        .db()
2178        .await
2179        .update_diagnostic_summary(&message, session.connection_id)
2180        .await?;
2181
2182    broadcast(
2183        Some(session.connection_id),
2184        guest_connection_ids.iter().copied(),
2185        |connection_id| {
2186            session
2187                .peer
2188                .forward_send(session.connection_id, connection_id, message.clone())
2189        },
2190    );
2191
2192    Ok(())
2193}
2194
2195/// Updates other participants with changes to the worktree settings
2196async fn update_worktree_settings(
2197    message: proto::UpdateWorktreeSettings,
2198    session: Session,
2199) -> Result<()> {
2200    let guest_connection_ids = session
2201        .db()
2202        .await
2203        .update_worktree_settings(&message, session.connection_id)
2204        .await?;
2205
2206    broadcast(
2207        Some(session.connection_id),
2208        guest_connection_ids.iter().copied(),
2209        |connection_id| {
2210            session
2211                .peer
2212                .forward_send(session.connection_id, connection_id, message.clone())
2213        },
2214    );
2215
2216    Ok(())
2217}
2218
2219/// Notify other participants that a language server has started.
2220async fn start_language_server(
2221    request: proto::StartLanguageServer,
2222    session: Session,
2223) -> Result<()> {
2224    let guest_connection_ids = session
2225        .db()
2226        .await
2227        .start_language_server(&request, session.connection_id)
2228        .await?;
2229
2230    broadcast(
2231        Some(session.connection_id),
2232        guest_connection_ids.iter().copied(),
2233        |connection_id| {
2234            session
2235                .peer
2236                .forward_send(session.connection_id, connection_id, request.clone())
2237        },
2238    );
2239    Ok(())
2240}
2241
2242/// Notify other participants that a language server has changed.
2243async fn update_language_server(
2244    request: proto::UpdateLanguageServer,
2245    session: Session,
2246) -> Result<()> {
2247    let project_id = ProjectId::from_proto(request.project_id);
2248    let project_connection_ids = session
2249        .db()
2250        .await
2251        .project_connection_ids(project_id, session.connection_id, true)
2252        .await?;
2253    broadcast(
2254        Some(session.connection_id),
2255        project_connection_ids.iter().copied(),
2256        |connection_id| {
2257            session
2258                .peer
2259                .forward_send(session.connection_id, connection_id, request.clone())
2260        },
2261    );
2262    Ok(())
2263}
2264
2265/// forward a project request to the host. These requests should be read only
2266/// as guests are allowed to send them.
2267async fn forward_read_only_project_request<T>(
2268    request: T,
2269    response: Response<T>,
2270    session: Session,
2271) -> Result<()>
2272where
2273    T: EntityMessage + RequestMessage,
2274{
2275    let project_id = ProjectId::from_proto(request.remote_entity_id());
2276    let host_connection_id = session
2277        .db()
2278        .await
2279        .host_for_read_only_project_request(project_id, session.connection_id)
2280        .await?;
2281    let payload = session
2282        .peer
2283        .forward_request(session.connection_id, host_connection_id, request)
2284        .await?;
2285    response.send(payload)?;
2286    Ok(())
2287}
2288
2289async fn forward_find_search_candidates_request(
2290    request: proto::FindSearchCandidates,
2291    response: Response<proto::FindSearchCandidates>,
2292    session: Session,
2293) -> Result<()> {
2294    let project_id = ProjectId::from_proto(request.remote_entity_id());
2295    let host_connection_id = session
2296        .db()
2297        .await
2298        .host_for_read_only_project_request(project_id, session.connection_id)
2299        .await?;
2300    let payload = session
2301        .peer
2302        .forward_request(session.connection_id, host_connection_id, request)
2303        .await?;
2304    response.send(payload)?;
2305    Ok(())
2306}
2307
2308/// forward a project request to the host. These requests are disallowed
2309/// for guests.
2310async fn forward_mutating_project_request<T>(
2311    request: T,
2312    response: Response<T>,
2313    session: Session,
2314) -> Result<()>
2315where
2316    T: EntityMessage + RequestMessage,
2317{
2318    let project_id = ProjectId::from_proto(request.remote_entity_id());
2319
2320    let host_connection_id = session
2321        .db()
2322        .await
2323        .host_for_mutating_project_request(project_id, session.connection_id)
2324        .await?;
2325    let payload = session
2326        .peer
2327        .forward_request(session.connection_id, host_connection_id, request)
2328        .await?;
2329    response.send(payload)?;
2330    Ok(())
2331}
2332
2333async fn multi_lsp_query(
2334    request: MultiLspQuery,
2335    response: Response<MultiLspQuery>,
2336    session: Session,
2337) -> Result<()> {
2338    tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2339    forward_mutating_project_request(request, response, session).await
2340}
2341
2342/// Notify other participants that a new buffer has been created
2343async fn create_buffer_for_peer(
2344    request: proto::CreateBufferForPeer,
2345    session: Session,
2346) -> Result<()> {
2347    session
2348        .db()
2349        .await
2350        .check_user_is_project_host(
2351            ProjectId::from_proto(request.project_id),
2352            session.connection_id,
2353        )
2354        .await?;
2355    let peer_id = request.peer_id.context("invalid peer id")?;
2356    session
2357        .peer
2358        .forward_send(session.connection_id, peer_id.into(), request)?;
2359    Ok(())
2360}
2361
2362/// Notify other participants that a buffer has been updated. This is
2363/// allowed for guests as long as the update is limited to selections.
2364async fn update_buffer(
2365    request: proto::UpdateBuffer,
2366    response: Response<proto::UpdateBuffer>,
2367    session: Session,
2368) -> Result<()> {
2369    let project_id = ProjectId::from_proto(request.project_id);
2370    let mut capability = Capability::ReadOnly;
2371
2372    for op in request.operations.iter() {
2373        match op.variant {
2374            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2375            Some(_) => capability = Capability::ReadWrite,
2376        }
2377    }
2378
2379    let host = {
2380        let guard = session
2381            .db()
2382            .await
2383            .connections_for_buffer_update(project_id, session.connection_id, capability)
2384            .await?;
2385
2386        let (host, guests) = &*guard;
2387
2388        broadcast(
2389            Some(session.connection_id),
2390            guests.clone(),
2391            |connection_id| {
2392                session
2393                    .peer
2394                    .forward_send(session.connection_id, connection_id, request.clone())
2395            },
2396        );
2397
2398        *host
2399    };
2400
2401    if host != session.connection_id {
2402        session
2403            .peer
2404            .forward_request(session.connection_id, host, request.clone())
2405            .await?;
2406    }
2407
2408    response.send(proto::Ack {})?;
2409    Ok(())
2410}
2411
/// Relay an `UpdateContext` message to the project's other connections,
/// including the host.
///
/// The capability passed to `connections_for_buffer_update` depends on the
/// operation. Buffer operations that only update selections (or carry no
/// variant) use `ReadOnly`, as does a context operation with no variant. Every
/// other operation uses `ReadWrite`, including a buffer operation whose inner
/// operation is missing. Fails if the message has no `operation`.
async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let operation = message.operation.as_ref().context("invalid operation")?;
    let capability = match operation.variant.as_ref() {
        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
            if let Some(buffer_op) = buffer_op.operation.as_ref() {
                match buffer_op.variant {
                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
                        Capability::ReadOnly
                    }
                    _ => Capability::ReadWrite,
                }
            } else {
                // A buffer operation with no inner operation is treated as a write.
                Capability::ReadWrite
            }
        }
        Some(_) => Capability::ReadWrite,
        None => Capability::ReadOnly,
    };

    // NOTE(review): presumably this rejects senders lacking `capability` — confirm
    // against `connections_for_buffer_update`.
    let guard = session
        .db()
        .await
        .connections_for_buffer_update(project_id, session.connection_id, capability)
        .await?;

    let (host, guests) = &*guard;

    // Unlike `update_buffer`, the host gets a one-way message alongside the guests.
    broadcast(
        Some(session.connection_id),
        guests.iter().chain([host]).copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2453
/// Notify other participants that a project has been updated.
///
/// Forwards `request` unchanged to the project's connections, as returned by
/// `project_connection_ids`. The project is identified by the message's
/// remote entity id. The sender's connection is passed to `broadcast` as the
/// origin, presumably so that it is skipped.
///
/// NOTE(review): this handler does not itself check that the sender is the
/// project host; presumably that is enforced elsewhere — confirm.
async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
    request: T,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.remote_entity_id());
    // NOTE(review): the meaning of the `false` flag is defined by
    // `project_connection_ids` (`update_followers` passes `true`) — confirm.
    let project_connection_ids = session
        .db()
        .await
        .project_connection_ids(project_id, session.connection_id, false)
        .await?;

    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2477
/// Start following another user in a call.
///
/// Calls `check_room_participants`, presumably to verify that the leader and
/// the follower both belong to the room. The request is then forwarded to the
/// leader's connection, and the leader's response is relayed back to the
/// follower. When a project is specified, the follow is also recorded in the
/// database and the resulting room is broadcast via `room_updated`.
async fn follow(
    request: proto::Follow,
    response: Response<proto::Follow>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // The follower gets the leader's reply before the follow is persisted.
    let response_payload = session
        .peer
        .forward_request(session.connection_id, leader_id, request)
        .await?;
    response.send(response_payload)?;

    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .follow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2512
/// Stop following another user in a call.
///
/// Calls `check_room_participants` (presumably verifying that both connections
/// are in the room) and forwards the request to the leader as a one-way
/// message. When a project is specified, the unfollow is recorded in the
/// database and the resulting room is broadcast via `room_updated`.
async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Unlike `follow`, no reply is expected from the leader.
    session
        .peer
        .forward_send(session.connection_id, leader_id, request)?;

    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .unfollow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2541
2542/// Notify everyone following you of your current location.
2543async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2544    let room_id = RoomId::from_proto(request.room_id);
2545    let database = session.db.lock().await;
2546
2547    let connection_ids = if let Some(project_id) = request.project_id {
2548        let project_id = ProjectId::from_proto(project_id);
2549        database
2550            .project_connection_ids(project_id, session.connection_id, true)
2551            .await?
2552    } else {
2553        database
2554            .room_connection_ids(room_id, session.connection_id)
2555            .await?
2556    };
2557
2558    // For now, don't send view update messages back to that view's current leader.
2559    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2560        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2561        _ => None,
2562    });
2563
2564    for connection_id in connection_ids.iter().cloned() {
2565        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2566            session
2567                .peer
2568                .forward_send(session.connection_id, connection_id, request.clone())?;
2569        }
2570    }
2571    Ok(())
2572}
2573
/// Get public data about users.
///
/// Looks up the requested ids and returns each user's id, GitHub login and
/// name, plus an avatar URL built from the GitHub login (`?size=128`).
async fn get_users(
    request: proto::GetUsers,
    response: Response<proto::GetUsers>,
    session: Session,
) -> Result<()> {
    let user_ids = request
        .user_ids
        .into_iter()
        .map(UserId::from_proto)
        .collect();
    let users = session
        .db()
        .await
        .get_users_by_ids(user_ids)
        .await?
        .into_iter()
        .map(|user| proto::User {
            id: user.id.to_proto(),
            // Computed before `github_login` is moved into the struct below.
            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
            github_login: user.github_login,
            name: user.name,
        })
        .collect();
    response.send(proto::UsersResponse { users })?;
    Ok(())
}
2601
/// Search for users (to invite) by GitHub login.
///
/// An empty query returns no users. One- or two-byte queries are looked up
/// with `get_user_by_github_login` (an exact login lookup, judging by the name).
/// Longer queries use fuzzy search, capped at 10 results. The requesting user is
/// always filtered out of the response.
async fn fuzzy_search_users(
    request: proto::FuzzySearchUsers,
    response: Response<proto::FuzzySearchUsers>,
    session: Session,
) -> Result<()> {
    let query = request.query;
    // `len()` counts bytes, not characters.
    let users = match query.len() {
        0 => vec![],
        1 | 2 => session
            .db()
            .await
            .get_user_by_github_login(&query)
            .await?
            .into_iter()
            .collect(),
        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
    };
    let users = users
        .into_iter()
        .filter(|user| user.id != session.user_id())
        .map(|user| proto::User {
            id: user.id.to_proto(),
            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
            github_login: user.github_login,
            name: user.name,
        })
        .collect();
    response.send(proto::UsersResponse { users })?;
    Ok(())
}
2633
2634/// Send a contact request to another user.
2635async fn request_contact(
2636    request: proto::RequestContact,
2637    response: Response<proto::RequestContact>,
2638    session: Session,
2639) -> Result<()> {
2640    let requester_id = session.user_id();
2641    let responder_id = UserId::from_proto(request.responder_id);
2642    if requester_id == responder_id {
2643        return Err(anyhow!("cannot add yourself as a contact"))?;
2644    }
2645
2646    let notifications = session
2647        .db()
2648        .await
2649        .send_contact_request(requester_id, responder_id)
2650        .await?;
2651
2652    // Update outgoing contact requests of requester
2653    let mut update = proto::UpdateContacts::default();
2654    update.outgoing_requests.push(responder_id.to_proto());
2655    for connection_id in session
2656        .connection_pool()
2657        .await
2658        .user_connection_ids(requester_id)
2659    {
2660        session.peer.send(connection_id, update.clone())?;
2661    }
2662
2663    // Update incoming contact requests of responder
2664    let mut update = proto::UpdateContacts::default();
2665    update
2666        .incoming_requests
2667        .push(proto::IncomingContactRequest {
2668            requester_id: requester_id.to_proto(),
2669        });
2670    let connection_pool = session.connection_pool().await;
2671    for connection_id in connection_pool.user_connection_ids(responder_id) {
2672        session.peer.send(connection_id, update.clone())?;
2673    }
2674
2675    send_notifications(&connection_pool, &session.peer, notifications);
2676
2677    response.send(proto::Ack {})?;
2678    Ok(())
2679}
2680
/// Accept, decline, or dismiss a contact request.
///
/// A `Dismiss` response only calls `dismiss_contact_notification`, and no
/// contact updates are sent. Otherwise the response is recorded: `Accept`
/// accepts, and any other value declines. Both users' connections then receive
/// an `UpdateContacts` that removes the pending request and, on acceptance,
/// adds the other user as a contact with their busy status. Notifications
/// produced by the database are delivered last.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: Session,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    // The database guard is held for the rest of the handler.
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2738
/// Remove a contact, or withdraw a pending outgoing contact request.
///
/// `remove_contact` returns whether the two users were accepted contacts. If
/// they were, both sides receive `remove_contacts`. Otherwise the requester's
/// outgoing request and the responder's incoming request are removed. If the
/// database also deleted a notification, the responder's connections are told
/// to delete it.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: Session,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update the requester: drop the contact, or their outgoing request.
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update the responder: drop the contact, or their incoming request.
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2789
2790fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2791    version.0.minor() < 139
2792}
2793
2794async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2795    if is_staff {
2796        return Ok(proto::Plan::ZedPro);
2797    }
2798
2799    let subscription = db.get_active_billing_subscription(user_id).await?;
2800    let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2801
2802    let plan = if let Some(subscription_kind) = subscription_kind {
2803        match subscription_kind {
2804            SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2805            SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2806            SubscriptionKind::ZedFree => proto::Plan::Free,
2807        }
2808    } else {
2809        proto::Plan::Free
2810    };
2811
2812    Ok(plan)
2813}
2814
/// Builds the `UpdateUserPlan` message describing `user`'s plan, trial start,
/// billing state, account-age restriction and LLM usage.
///
/// * `is_staff`: staff are reported as `ZedPro` (via `current_plan`) and always
///   have usage-based billing enabled.
/// * `llm_db`: when `None`, no subscription period or usage is looked up, and
///   the message carries zero usage with the plan's default limits.
///
/// Timestamps are sent as Unix seconds.
async fn make_update_user_plan_message(
    user: &User,
    is_staff: bool,
    db: &Arc<Database>,
    llm_db: Option<Arc<LlmDatabase>>,
) -> Result<proto::UpdateUserPlan> {
    let feature_flags = db.get_user_flags(user.id).await?;
    let plan = current_plan(db, user.id, is_staff).await?;
    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
    let billing_preferences = db.get_billing_preferences(user.id).await?;

    // Usage is read from the LLM database for the current subscription period,
    // so neither value is available without that database.
    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
        let subscription = db.get_active_billing_subscription(user.id).await?;

        let subscription_period =
            crate::db::billing_subscription::Model::current_period(subscription, is_staff);

        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
            llm_db
                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
                .await?
        } else {
            None
        };

        (subscription_period, usage)
    } else {
        (None, None)
    };

    // Pro users, and users with the bypass feature flag, are exempt from the
    // minimum account age.
    let bypass_account_age_check = feature_flags
        .iter()
        .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
    let account_too_young = !matches!(plan, proto::Plan::ZedPro)
        && !bypass_account_age_check
        && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;

    Ok(proto::UpdateUserPlan {
        plan: plan.into(),
        // The stored naive datetime is interpreted as UTC.
        trial_started_at: billing_customer
            .as_ref()
            .and_then(|billing_customer| billing_customer.trial_started_at)
            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
        is_usage_based_billing_enabled: if is_staff {
            Some(true)
        } else {
            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
        },
        subscription_period: subscription_period.map(|(started_at, ended_at)| {
            proto::SubscriptionPeriod {
                started_at: started_at.timestamp() as u64,
                ended_at: ended_at.timestamp() as u64,
            }
        }),
        account_too_young: Some(account_too_young),
        has_overdue_invoices: billing_customer
            .map(|billing_customer| billing_customer.has_overdue_invoices),
        // Without a usage record, report zero usage together with the plan's limits.
        usage: Some(
            usage
                .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
                .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
        ),
    })
}
2879
2880fn model_requests_limit(
2881    plan: cloud_llm_client::Plan,
2882    feature_flags: &Vec<String>,
2883) -> cloud_llm_client::UsageLimit {
2884    match plan.model_requests_limit() {
2885        cloud_llm_client::UsageLimit::Limited(limit) => {
2886            let limit = if plan == cloud_llm_client::Plan::ZedProTrial
2887                && feature_flags
2888                    .iter()
2889                    .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2890            {
2891                1_000
2892            } else {
2893                limit
2894            };
2895
2896            cloud_llm_client::UsageLimit::Limited(limit)
2897        }
2898        cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited,
2899    }
2900}
2901
2902fn subscription_usage_to_proto(
2903    plan: proto::Plan,
2904    usage: crate::llm::db::subscription_usage::Model,
2905    feature_flags: &Vec<String>,
2906) -> proto::SubscriptionUsage {
2907    let plan = match plan {
2908        proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2909        proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
2910        proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
2911    };
2912
2913    proto::SubscriptionUsage {
2914        model_requests_usage_amount: usage.model_requests as u32,
2915        model_requests_usage_limit: Some(proto::UsageLimit {
2916            variant: Some(match model_requests_limit(plan, feature_flags) {
2917                cloud_llm_client::UsageLimit::Limited(limit) => {
2918                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2919                        limit: limit as u32,
2920                    })
2921                }
2922                cloud_llm_client::UsageLimit::Unlimited => {
2923                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2924                }
2925            }),
2926        }),
2927        edit_predictions_usage_amount: usage.edit_predictions as u32,
2928        edit_predictions_usage_limit: Some(proto::UsageLimit {
2929            variant: Some(match plan.edit_predictions_limit() {
2930                cloud_llm_client::UsageLimit::Limited(limit) => {
2931                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2932                        limit: limit as u32,
2933                    })
2934                }
2935                cloud_llm_client::UsageLimit::Unlimited => {
2936                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2937                }
2938            }),
2939        }),
2940    }
2941}
2942
2943fn make_default_subscription_usage(
2944    plan: proto::Plan,
2945    feature_flags: &Vec<String>,
2946) -> proto::SubscriptionUsage {
2947    let plan = match plan {
2948        proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2949        proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
2950        proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
2951    };
2952
2953    proto::SubscriptionUsage {
2954        model_requests_usage_amount: 0,
2955        model_requests_usage_limit: Some(proto::UsageLimit {
2956            variant: Some(match model_requests_limit(plan, feature_flags) {
2957                cloud_llm_client::UsageLimit::Limited(limit) => {
2958                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2959                        limit: limit as u32,
2960                    })
2961                }
2962                cloud_llm_client::UsageLimit::Unlimited => {
2963                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2964                }
2965            }),
2966        }),
2967        edit_predictions_usage_amount: 0,
2968        edit_predictions_usage_limit: Some(proto::UsageLimit {
2969            variant: Some(match plan.edit_predictions_limit() {
2970                cloud_llm_client::UsageLimit::Limited(limit) => {
2971                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2972                        limit: limit as u32,
2973                    })
2974                }
2975                cloud_llm_client::UsageLimit::Unlimited => {
2976                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2977                }
2978            }),
2979        }),
2980    }
2981}
2982
/// Builds the current `UpdateUserPlan` message for the session's user and sends
/// it to this session's connection.
///
/// Errors while building the message are propagated. A failure to send it is
/// only passed to `trace_err` and does not fail the call.
async fn update_user_plan(session: &Session) -> Result<()> {
    // The database guard is held while the message is built.
    let db = session.db().await;

    let update_user_plan = make_update_user_plan_message(
        session.principal.user(),
        session.is_staff(),
        &db.0,
        session.app_state.llm_db.clone(),
    )
    .await?;

    session
        .peer
        .send(session.connection_id, update_user_plan)
        .trace_err();

    Ok(())
}
3001
3002async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3003    subscribe_user_to_channels(session.user_id(), &session).await?;
3004    Ok(())
3005}
3006
/// Subscribes `user_id` in the connection pool to every channel returned by
/// `get_channels_for_user`. Then sends *this session's* connection the user's
/// memberships (`build_update_user_channels`), followed by the channel data
/// (`build_channels_update`).
async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
    // The pool lock is held until both messages below have been sent.
    let mut pool = session.connection_pool().await;
    for membership in &channels_for_user.channel_memberships {
        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    session.peer.send(
        session.connection_id,
        build_update_user_channels(&channels_for_user),
    )?;
    // `channels_for_user` is consumed here, so this must come last.
    session.peer.send(
        session.connection_id,
        build_channels_update(channels_for_user),
    )?;
    Ok(())
}
3023
3024/// Creates a new channel.
3025async fn create_channel(
3026    request: proto::CreateChannel,
3027    response: Response<proto::CreateChannel>,
3028    session: Session,
3029) -> Result<()> {
3030    let db = session.db().await;
3031
3032    let parent_id = request.parent_id.map(ChannelId::from_proto);
3033    let (channel, membership) = db
3034        .create_channel(&request.name, parent_id, session.user_id())
3035        .await?;
3036
3037    let root_id = channel.root_id();
3038    let channel = Channel::from_model(channel);
3039
3040    response.send(proto::CreateChannelResponse {
3041        channel: Some(channel.to_proto()),
3042        parent_id: request.parent_id,
3043    })?;
3044
3045    let mut connection_pool = session.connection_pool().await;
3046    if let Some(membership) = membership {
3047        connection_pool.subscribe_to_channel(
3048            membership.user_id,
3049            membership.channel_id,
3050            membership.role,
3051        );
3052        let update = proto::UpdateUserChannels {
3053            channel_memberships: vec![proto::ChannelMembership {
3054                channel_id: membership.channel_id.to_proto(),
3055                role: membership.role.into(),
3056            }],
3057            ..Default::default()
3058        };
3059        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3060            session.peer.send(connection_id, update.clone())?;
3061        }
3062    }
3063
3064    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3065        if !role.can_see_channel(channel.visibility) {
3066            continue;
3067        }
3068
3069        let update = proto::UpdateChannels {
3070            channels: vec![channel.to_proto()],
3071            ..Default::default()
3072        };
3073        session.peer.send(connection_id, update.clone())?;
3074    }
3075
3076    Ok(())
3077}
3078
/// Delete a channel.
///
/// Acknowledges the request once the database has deleted the channel. Then
/// every connection subscribed to the channel's root is told to drop all
/// channel ids that the database reported as removed, regardless of role.
/// NOTE(review): the permission check presumably happens in
/// `Database::delete_channel`, which receives the caller's user id — confirm.
async fn delete_channel(
    request: proto::DeleteChannel,
    response: Response<proto::DeleteChannel>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let channel_id = request.channel_id;
    let (root_channel, removed_channels) = db
        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
        .await?;
    response.send(proto::Ack {})?;

    // Notify members of removed channels
    let mut update = proto::UpdateChannels::default();
    update
        .delete_channels
        .extend(removed_channels.into_iter().map(|id| id.to_proto()));

    let connection_pool = session.connection_pool().await;
    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
        session.peer.send(connection_id, update.clone())?;
    }

    Ok(())
}
3106
/// Invite someone to join a channel.
///
/// Records the invitation with the requested role and sends the channel to all
/// of the invitee's connections as a channel invitation. Then delivers any
/// notifications produced by the database and acknowledges the request.
async fn invite_channel_member(
    request: proto::InviteChannelMember,
    response: Response<proto::InviteChannelMember>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let invitee_id = UserId::from_proto(request.user_id);
    let InviteMemberResult {
        channel,
        notifications,
    } = db
        .invite_channel_member(
            channel_id,
            invitee_id,
            session.user_id(),
            request.role().into(),
        )
        .await?;

    let update = proto::UpdateChannels {
        channel_invitations: vec![channel.to_proto()],
        ..Default::default()
    };

    let connection_pool = session.connection_pool().await;
    for connection_id in connection_pool.user_connection_ids(invitee_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;
    Ok(())
}
3143
3144/// remove someone from a channel
3145async fn remove_channel_member(
3146    request: proto::RemoveChannelMember,
3147    response: Response<proto::RemoveChannelMember>,
3148    session: Session,
3149) -> Result<()> {
3150    let db = session.db().await;
3151    let channel_id = ChannelId::from_proto(request.channel_id);
3152    let member_id = UserId::from_proto(request.user_id);
3153
3154    let RemoveChannelMemberResult {
3155        membership_update,
3156        notification_id,
3157    } = db
3158        .remove_channel_member(channel_id, member_id, session.user_id())
3159        .await?;
3160
3161    let mut connection_pool = session.connection_pool().await;
3162    notify_membership_updated(
3163        &mut connection_pool,
3164        membership_update,
3165        member_id,
3166        &session.peer,
3167    );
3168    for connection_id in connection_pool.user_connection_ids(member_id) {
3169        if let Some(notification_id) = notification_id {
3170            session
3171                .peer
3172                .send(
3173                    connection_id,
3174                    proto::DeleteNotification {
3175                        notification_id: notification_id.to_proto(),
3176                    },
3177                )
3178                .trace_err();
3179        }
3180    }
3181
3182    response.send(proto::Ack {})?;
3183    Ok(())
3184}
3185
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
///
/// After the database update, every user subscribed to the channel's root is
/// re-evaluated against the new visibility. Users whose role can see the
/// channel are subscribed to it and sent the channel. All others are
/// unsubscribed and told to delete it.
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect first: the loop body mutates the pool while walking its users.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3232
/// Alter the role for a user in the channel.
///
/// When the database reports `MembershipUpdated`, the change is applied to the
/// connection pool and announced via `notify_membership_updated`. When it
/// reports `InviteUpdated` (presumably the user only has a pending invitation),
/// the user's connections receive the updated invitation instead. The request
/// is acknowledged afterwards.
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id(),
            member_id,
            request.role().into(),
        )
        .await?;

    match result {
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            let mut connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &mut connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3280
3281/// Change the name of a channel
3282async fn rename_channel(
3283    request: proto::RenameChannel,
3284    response: Response<proto::RenameChannel>,
3285    session: Session,
3286) -> Result<()> {
3287    let db = session.db().await;
3288    let channel_id = ChannelId::from_proto(request.channel_id);
3289    let channel_model = db
3290        .rename_channel(channel_id, session.user_id(), &request.name)
3291        .await?;
3292    let root_id = channel_model.root_id();
3293    let channel = Channel::from_model(channel_model);
3294
3295    response.send(proto::RenameChannelResponse {
3296        channel: Some(channel.to_proto()),
3297    })?;
3298
3299    let connection_pool = session.connection_pool().await;
3300    let update = proto::UpdateChannels {
3301        channels: vec![channel.to_proto()],
3302        ..Default::default()
3303    };
3304    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3305        if role.can_see_channel(channel.visibility) {
3306            session.peer.send(connection_id, update.clone())?;
3307        }
3308    }
3309
3310    Ok(())
3311}
3312
3313/// Move a channel to a new parent.
3314async fn move_channel(
3315    request: proto::MoveChannel,
3316    response: Response<proto::MoveChannel>,
3317    session: Session,
3318) -> Result<()> {
3319    let channel_id = ChannelId::from_proto(request.channel_id);
3320    let to = ChannelId::from_proto(request.to);
3321
3322    let (root_id, channels) = session
3323        .db()
3324        .await
3325        .move_channel(channel_id, to, session.user_id())
3326        .await?;
3327
3328    let connection_pool = session.connection_pool().await;
3329    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3330        let channels = channels
3331            .iter()
3332            .filter_map(|channel| {
3333                if role.can_see_channel(channel.visibility) {
3334                    Some(channel.to_proto())
3335                } else {
3336                    None
3337                }
3338            })
3339            .collect::<Vec<_>>();
3340        if channels.is_empty() {
3341            continue;
3342        }
3343
3344        let update = proto::UpdateChannels {
3345            channels,
3346            ..Default::default()
3347        };
3348
3349        session.peer.send(connection_id, update.clone())?;
3350    }
3351
3352    response.send(Ack {})?;
3353    Ok(())
3354}
3355
3356async fn reorder_channel(
3357    request: proto::ReorderChannel,
3358    response: Response<proto::ReorderChannel>,
3359    session: Session,
3360) -> Result<()> {
3361    let channel_id = ChannelId::from_proto(request.channel_id);
3362    let direction = request.direction();
3363
3364    let updated_channels = session
3365        .db()
3366        .await
3367        .reorder_channel(channel_id, direction, session.user_id())
3368        .await?;
3369
3370    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3371        let connection_pool = session.connection_pool().await;
3372        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3373            let channels = updated_channels
3374                .iter()
3375                .filter_map(|channel| {
3376                    if role.can_see_channel(channel.visibility) {
3377                        Some(channel.to_proto())
3378                    } else {
3379                        None
3380                    }
3381                })
3382                .collect::<Vec<_>>();
3383
3384            if channels.is_empty() {
3385                continue;
3386            }
3387
3388            let update = proto::UpdateChannels {
3389                channels,
3390                ..Default::default()
3391            };
3392
3393            session.peer.send(connection_id, update.clone())?;
3394        }
3395    }
3396
3397    response.send(Ack {})?;
3398    Ok(())
3399}
3400
3401/// Get the list of channel members
3402async fn get_channel_members(
3403    request: proto::GetChannelMembers,
3404    response: Response<proto::GetChannelMembers>,
3405    session: Session,
3406) -> Result<()> {
3407    let db = session.db().await;
3408    let channel_id = ChannelId::from_proto(request.channel_id);
3409    let limit = if request.limit == 0 {
3410        u16::MAX as u64
3411    } else {
3412        request.limit
3413    };
3414    let (members, users) = db
3415        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3416        .await?;
3417    response.send(proto::GetChannelMembersResponse { members, users })?;
3418    Ok(())
3419}
3420
3421/// Accept or decline a channel invitation.
3422async fn respond_to_channel_invite(
3423    request: proto::RespondToChannelInvite,
3424    response: Response<proto::RespondToChannelInvite>,
3425    session: Session,
3426) -> Result<()> {
3427    let db = session.db().await;
3428    let channel_id = ChannelId::from_proto(request.channel_id);
3429    let RespondToChannelInvite {
3430        membership_update,
3431        notifications,
3432    } = db
3433        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3434        .await?;
3435
3436    let mut connection_pool = session.connection_pool().await;
3437    if let Some(membership_update) = membership_update {
3438        notify_membership_updated(
3439            &mut connection_pool,
3440            membership_update,
3441            session.user_id(),
3442            &session.peer,
3443        );
3444    } else {
3445        let update = proto::UpdateChannels {
3446            remove_channel_invitations: vec![channel_id.to_proto()],
3447            ..Default::default()
3448        };
3449
3450        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3451            session.peer.send(connection_id, update.clone())?;
3452        }
3453    };
3454
3455    send_notifications(&connection_pool, &session.peer, notifications);
3456
3457    response.send(proto::Ack {})?;
3458
3459    Ok(())
3460}
3461
3462/// Join the channels' room
3463async fn join_channel(
3464    request: proto::JoinChannel,
3465    response: Response<proto::JoinChannel>,
3466    session: Session,
3467) -> Result<()> {
3468    let channel_id = ChannelId::from_proto(request.channel_id);
3469    join_channel_internal(channel_id, Box::new(response), session).await
3470}
3471
/// A response handle that can answer a room join with a `JoinRoomResponse`.
///
/// Implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so [`join_channel_internal`] can serve
/// either request type.
trait JoinChannelInternalResponse {
    /// Consume the handle and reply with `result`.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Answers a `JoinChannel` request with the shared `JoinRoomResponse` payload.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified so it is intended to resolve to the inherent
        // `Response::send`, not back into this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Answers a `JoinRoom` request with the shared `JoinRoomResponse` payload.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified so it is intended to resolve to the inherent
        // `Response::send`, not back into this trait method.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3485
/// Join the room attached to `channel_id` for the current session and reply
/// on `response`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first.
/// 2. Join the channel's room. If a LiveKit client is configured, build
///    LiveKit credentials: guests get a guest token with `can_publish: false`,
///    every other role gets a room token with `can_publish: true`.
/// 3. Send the `JoinRoomResponse`, push any membership update, and broadcast
///    the room update.
/// 4. Once the inner scope has dropped its database and connection-pool
///    guards, broadcast the channel update and refresh the user's contacts.
///
/// # Errors
///
/// Propagates failures from the database, from leaving a stale room, from
/// sending the response, and from updating contacts. The
/// "channel not returned" error is raised only *after* the
/// `JoinRoomResponse` has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The `db` and `connection_pool` guards below live only inside this
    // block, so both are dropped before `channel_updated` takes the pool again.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room
            // (presumably `leave_room_for_session` needs it), then re-acquire.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when generating the
        // token fails: `trace_err` turns the error into `None`, and `?` then
        // short-circuits this closure.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests may join but not publish media.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The response has already been sent, so a missing channel here surfaces
    // as an error after the client has seen a successful join.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3579
3580/// Start editing the channel notes
3581async fn join_channel_buffer(
3582    request: proto::JoinChannelBuffer,
3583    response: Response<proto::JoinChannelBuffer>,
3584    session: Session,
3585) -> Result<()> {
3586    let db = session.db().await;
3587    let channel_id = ChannelId::from_proto(request.channel_id);
3588
3589    let open_response = db
3590        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3591        .await?;
3592
3593    let collaborators = open_response.collaborators.clone();
3594    response.send(open_response)?;
3595
3596    let update = UpdateChannelBufferCollaborators {
3597        channel_id: channel_id.to_proto(),
3598        collaborators: collaborators.clone(),
3599    };
3600    channel_buffer_updated(
3601        session.connection_id,
3602        collaborators
3603            .iter()
3604            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3605        &update,
3606        &session.peer,
3607    );
3608
3609    Ok(())
3610}
3611
3612/// Edit the channel notes
3613async fn update_channel_buffer(
3614    request: proto::UpdateChannelBuffer,
3615    session: Session,
3616) -> Result<()> {
3617    let db = session.db().await;
3618    let channel_id = ChannelId::from_proto(request.channel_id);
3619
3620    let (collaborators, epoch, version) = db
3621        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3622        .await?;
3623
3624    channel_buffer_updated(
3625        session.connection_id,
3626        collaborators.clone(),
3627        &proto::UpdateChannelBuffer {
3628            channel_id: channel_id.to_proto(),
3629            operations: request.operations,
3630        },
3631        &session.peer,
3632    );
3633
3634    let pool = &*session.connection_pool().await;
3635
3636    let non_collaborators =
3637        pool.channel_connection_ids(channel_id)
3638            .filter_map(|(connection_id, _)| {
3639                if collaborators.contains(&connection_id) {
3640                    None
3641                } else {
3642                    Some(connection_id)
3643                }
3644            });
3645
3646    broadcast(None, non_collaborators, |peer_id| {
3647        session.peer.send(
3648            peer_id,
3649            proto::UpdateChannels {
3650                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3651                    channel_id: channel_id.to_proto(),
3652                    epoch: epoch as u64,
3653                    version: version.clone(),
3654                }],
3655                ..Default::default()
3656            },
3657        )
3658    });
3659
3660    Ok(())
3661}
3662
3663/// Rejoin the channel notes after a connection blip
3664async fn rejoin_channel_buffers(
3665    request: proto::RejoinChannelBuffers,
3666    response: Response<proto::RejoinChannelBuffers>,
3667    session: Session,
3668) -> Result<()> {
3669    let db = session.db().await;
3670    let buffers = db
3671        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3672        .await?;
3673
3674    for rejoined_buffer in &buffers {
3675        let collaborators_to_notify = rejoined_buffer
3676            .buffer
3677            .collaborators
3678            .iter()
3679            .filter_map(|c| Some(c.peer_id?.into()));
3680        channel_buffer_updated(
3681            session.connection_id,
3682            collaborators_to_notify,
3683            &proto::UpdateChannelBufferCollaborators {
3684                channel_id: rejoined_buffer.buffer.channel_id,
3685                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3686            },
3687            &session.peer,
3688        );
3689    }
3690
3691    response.send(proto::RejoinChannelBuffersResponse {
3692        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3693    })?;
3694
3695    Ok(())
3696}
3697
3698/// Stop editing the channel notes
3699async fn leave_channel_buffer(
3700    request: proto::LeaveChannelBuffer,
3701    response: Response<proto::LeaveChannelBuffer>,
3702    session: Session,
3703) -> Result<()> {
3704    let db = session.db().await;
3705    let channel_id = ChannelId::from_proto(request.channel_id);
3706
3707    let left_buffer = db
3708        .leave_channel_buffer(channel_id, session.connection_id)
3709        .await?;
3710
3711    response.send(Ack {})?;
3712
3713    channel_buffer_updated(
3714        session.connection_id,
3715        left_buffer.connections,
3716        &proto::UpdateChannelBufferCollaborators {
3717            channel_id: channel_id.to_proto(),
3718            collaborators: left_buffer.collaborators,
3719        },
3720        &session.peer,
3721    );
3722
3723    Ok(())
3724}
3725
3726fn channel_buffer_updated<T: EnvelopedMessage>(
3727    sender_id: ConnectionId,
3728    collaborators: impl IntoIterator<Item = ConnectionId>,
3729    message: &T,
3730    peer: &Peer,
3731) {
3732    broadcast(Some(sender_id), collaborators, |peer_id| {
3733        peer.send(peer_id, message.clone())
3734    });
3735}
3736
3737fn send_notifications(
3738    connection_pool: &ConnectionPool,
3739    peer: &Peer,
3740    notifications: db::NotificationBatch,
3741) {
3742    for (user_id, notification) in notifications {
3743        for connection_id in connection_pool.user_connection_ids(user_id) {
3744            if let Err(error) = peer.send(
3745                connection_id,
3746                proto::AddNotification {
3747                    notification: Some(notification.clone()),
3748                },
3749            ) {
3750                tracing::error!(
3751                    "failed to send notification to {:?} {}",
3752                    connection_id,
3753                    error
3754                );
3755            }
3756        }
3757    }
3758}
3759
3760/// Send a message to the channel
3761async fn send_channel_message(
3762    request: proto::SendChannelMessage,
3763    response: Response<proto::SendChannelMessage>,
3764    session: Session,
3765) -> Result<()> {
3766    // Validate the message body.
3767    let body = request.body.trim().to_string();
3768    if body.len() > MAX_MESSAGE_LEN {
3769        return Err(anyhow!("message is too long"))?;
3770    }
3771    if body.is_empty() {
3772        return Err(anyhow!("message can't be blank"))?;
3773    }
3774
3775    // TODO: adjust mentions if body is trimmed
3776
3777    let timestamp = OffsetDateTime::now_utc();
3778    let nonce = request.nonce.context("nonce can't be blank")?;
3779
3780    let channel_id = ChannelId::from_proto(request.channel_id);
3781    let CreatedChannelMessage {
3782        message_id,
3783        participant_connection_ids,
3784        notifications,
3785    } = session
3786        .db()
3787        .await
3788        .create_channel_message(
3789            channel_id,
3790            session.user_id(),
3791            &body,
3792            &request.mentions,
3793            timestamp,
3794            nonce.clone().into(),
3795            request.reply_to_message_id.map(MessageId::from_proto),
3796        )
3797        .await?;
3798
3799    let message = proto::ChannelMessage {
3800        sender_id: session.user_id().to_proto(),
3801        id: message_id.to_proto(),
3802        body,
3803        mentions: request.mentions,
3804        timestamp: timestamp.unix_timestamp() as u64,
3805        nonce: Some(nonce),
3806        reply_to_message_id: request.reply_to_message_id,
3807        edited_at: None,
3808    };
3809    broadcast(
3810        Some(session.connection_id),
3811        participant_connection_ids.clone(),
3812        |connection| {
3813            session.peer.send(
3814                connection,
3815                proto::ChannelMessageSent {
3816                    channel_id: channel_id.to_proto(),
3817                    message: Some(message.clone()),
3818                },
3819            )
3820        },
3821    );
3822    response.send(proto::SendChannelMessageResponse {
3823        message: Some(message),
3824    })?;
3825
3826    let pool = &*session.connection_pool().await;
3827    let non_participants =
3828        pool.channel_connection_ids(channel_id)
3829            .filter_map(|(connection_id, _)| {
3830                if participant_connection_ids.contains(&connection_id) {
3831                    None
3832                } else {
3833                    Some(connection_id)
3834                }
3835            });
3836    broadcast(None, non_participants, |peer_id| {
3837        session.peer.send(
3838            peer_id,
3839            proto::UpdateChannels {
3840                latest_channel_message_ids: vec![proto::ChannelMessageId {
3841                    channel_id: channel_id.to_proto(),
3842                    message_id: message_id.to_proto(),
3843                }],
3844                ..Default::default()
3845            },
3846        )
3847    });
3848    send_notifications(pool, &session.peer, notifications);
3849
3850    Ok(())
3851}
3852
3853/// Delete a channel message
3854async fn remove_channel_message(
3855    request: proto::RemoveChannelMessage,
3856    response: Response<proto::RemoveChannelMessage>,
3857    session: Session,
3858) -> Result<()> {
3859    let channel_id = ChannelId::from_proto(request.channel_id);
3860    let message_id = MessageId::from_proto(request.message_id);
3861    let (connection_ids, existing_notification_ids) = session
3862        .db()
3863        .await
3864        .remove_channel_message(channel_id, message_id, session.user_id())
3865        .await?;
3866
3867    broadcast(
3868        Some(session.connection_id),
3869        connection_ids,
3870        move |connection| {
3871            session.peer.send(connection, request.clone())?;
3872
3873            for notification_id in &existing_notification_ids {
3874                session.peer.send(
3875                    connection,
3876                    proto::DeleteNotification {
3877                        notification_id: (*notification_id).to_proto(),
3878                    },
3879                )?;
3880            }
3881
3882            Ok(())
3883        },
3884    );
3885    response.send(proto::Ack {})?;
3886    Ok(())
3887}
3888
3889async fn update_channel_message(
3890    request: proto::UpdateChannelMessage,
3891    response: Response<proto::UpdateChannelMessage>,
3892    session: Session,
3893) -> Result<()> {
3894    let channel_id = ChannelId::from_proto(request.channel_id);
3895    let message_id = MessageId::from_proto(request.message_id);
3896    let updated_at = OffsetDateTime::now_utc();
3897    let UpdatedChannelMessage {
3898        message_id,
3899        participant_connection_ids,
3900        notifications,
3901        reply_to_message_id,
3902        timestamp,
3903        deleted_mention_notification_ids,
3904        updated_mention_notifications,
3905    } = session
3906        .db()
3907        .await
3908        .update_channel_message(
3909            channel_id,
3910            message_id,
3911            session.user_id(),
3912            request.body.as_str(),
3913            &request.mentions,
3914            updated_at,
3915        )
3916        .await?;
3917
3918    let nonce = request.nonce.clone().context("nonce can't be blank")?;
3919
3920    let message = proto::ChannelMessage {
3921        sender_id: session.user_id().to_proto(),
3922        id: message_id.to_proto(),
3923        body: request.body.clone(),
3924        mentions: request.mentions.clone(),
3925        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3926        nonce: Some(nonce),
3927        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3928        edited_at: Some(updated_at.unix_timestamp() as u64),
3929    };
3930
3931    response.send(proto::Ack {})?;
3932
3933    let pool = &*session.connection_pool().await;
3934    broadcast(
3935        Some(session.connection_id),
3936        participant_connection_ids,
3937        |connection| {
3938            session.peer.send(
3939                connection,
3940                proto::ChannelMessageUpdate {
3941                    channel_id: channel_id.to_proto(),
3942                    message: Some(message.clone()),
3943                },
3944            )?;
3945
3946            for notification_id in &deleted_mention_notification_ids {
3947                session.peer.send(
3948                    connection,
3949                    proto::DeleteNotification {
3950                        notification_id: (*notification_id).to_proto(),
3951                    },
3952                )?;
3953            }
3954
3955            for notification in &updated_mention_notifications {
3956                session.peer.send(
3957                    connection,
3958                    proto::UpdateNotification {
3959                        notification: Some(notification.clone()),
3960                    },
3961                )?;
3962            }
3963
3964            Ok(())
3965        },
3966    );
3967
3968    send_notifications(pool, &session.peer, notifications);
3969
3970    Ok(())
3971}
3972
3973/// Mark a channel message as read
3974async fn acknowledge_channel_message(
3975    request: proto::AckChannelMessage,
3976    session: Session,
3977) -> Result<()> {
3978    let channel_id = ChannelId::from_proto(request.channel_id);
3979    let message_id = MessageId::from_proto(request.message_id);
3980    let notifications = session
3981        .db()
3982        .await
3983        .observe_channel_message(channel_id, session.user_id(), message_id)
3984        .await?;
3985    send_notifications(
3986        &*session.connection_pool().await,
3987        &session.peer,
3988        notifications,
3989    );
3990    Ok(())
3991}
3992
3993/// Mark a buffer version as synced
3994async fn acknowledge_buffer_version(
3995    request: proto::AckBufferOperation,
3996    session: Session,
3997) -> Result<()> {
3998    let buffer_id = BufferId::from_proto(request.buffer_id);
3999    session
4000        .db()
4001        .await
4002        .observe_buffer_version(
4003            buffer_id,
4004            session.user_id(),
4005            request.epoch as i32,
4006            &request.version,
4007        )
4008        .await?;
4009    Ok(())
4010}
4011
4012/// Get a Supermaven API key for the user
4013async fn get_supermaven_api_key(
4014    _request: proto::GetSupermavenApiKey,
4015    response: Response<proto::GetSupermavenApiKey>,
4016    session: Session,
4017) -> Result<()> {
4018    let user_id: String = session.user_id().to_string();
4019    if !session.is_staff() {
4020        return Err(anyhow!("supermaven not enabled for this account"))?;
4021    }
4022
4023    let email = session.email().context("user must have an email")?;
4024
4025    let supermaven_admin_api = session
4026        .supermaven_client
4027        .as_ref()
4028        .context("supermaven not configured")?;
4029
4030    let result = supermaven_admin_api
4031        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4032        .await?;
4033
4034    response.send(proto::GetSupermavenApiKeyResponse {
4035        api_key: result.api_key,
4036    })?;
4037
4038    Ok(())
4039}
4040
4041/// Start receiving chat updates for a channel
4042async fn join_channel_chat(
4043    request: proto::JoinChannelChat,
4044    response: Response<proto::JoinChannelChat>,
4045    session: Session,
4046) -> Result<()> {
4047    let channel_id = ChannelId::from_proto(request.channel_id);
4048
4049    let db = session.db().await;
4050    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4051        .await?;
4052    let messages = db
4053        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4054        .await?;
4055    response.send(proto::JoinChannelChatResponse {
4056        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4057        messages,
4058    })?;
4059    Ok(())
4060}
4061
4062/// Stop receiving chat updates for a channel
4063async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
4064    let channel_id = ChannelId::from_proto(request.channel_id);
4065    session
4066        .db()
4067        .await
4068        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4069        .await?;
4070    Ok(())
4071}
4072
4073/// Retrieve the chat history for a channel
4074async fn get_channel_messages(
4075    request: proto::GetChannelMessages,
4076    response: Response<proto::GetChannelMessages>,
4077    session: Session,
4078) -> Result<()> {
4079    let channel_id = ChannelId::from_proto(request.channel_id);
4080    let messages = session
4081        .db()
4082        .await
4083        .get_channel_messages(
4084            channel_id,
4085            session.user_id(),
4086            MESSAGE_COUNT_PER_PAGE,
4087            Some(MessageId::from_proto(request.before_message_id)),
4088        )
4089        .await?;
4090    response.send(proto::GetChannelMessagesResponse {
4091        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4092        messages,
4093    })?;
4094    Ok(())
4095}
4096
4097/// Retrieve specific chat messages
4098async fn get_channel_messages_by_id(
4099    request: proto::GetChannelMessagesById,
4100    response: Response<proto::GetChannelMessagesById>,
4101    session: Session,
4102) -> Result<()> {
4103    let message_ids = request
4104        .message_ids
4105        .iter()
4106        .map(|id| MessageId::from_proto(*id))
4107        .collect::<Vec<_>>();
4108    let messages = session
4109        .db()
4110        .await
4111        .get_channel_messages_by_id(session.user_id(), &message_ids)
4112        .await?;
4113    response.send(proto::GetChannelMessagesResponse {
4114        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4115        messages,
4116    })?;
4117    Ok(())
4118}
4119
4120/// Retrieve the current users notifications
4121async fn get_notifications(
4122    request: proto::GetNotifications,
4123    response: Response<proto::GetNotifications>,
4124    session: Session,
4125) -> Result<()> {
4126    let notifications = session
4127        .db()
4128        .await
4129        .get_notifications(
4130            session.user_id(),
4131            NOTIFICATION_COUNT_PER_PAGE,
4132            request.before_id.map(db::NotificationId::from_proto),
4133        )
4134        .await?;
4135    response.send(proto::GetNotificationsResponse {
4136        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4137        notifications,
4138    })?;
4139    Ok(())
4140}
4141
4142/// Mark notifications as read
4143async fn mark_notification_as_read(
4144    request: proto::MarkNotificationRead,
4145    response: Response<proto::MarkNotificationRead>,
4146    session: Session,
4147) -> Result<()> {
4148    let database = &session.db().await;
4149    let notifications = database
4150        .mark_notification_as_read_by_id(
4151            session.user_id(),
4152            NotificationId::from_proto(request.notification_id),
4153        )
4154        .await?;
4155    send_notifications(
4156        &*session.connection_pool().await,
4157        &session.peer,
4158        notifications,
4159    );
4160    response.send(proto::Ack {})?;
4161    Ok(())
4162}
4163
4164/// Get the current users information
4165async fn get_private_user_info(
4166    _request: proto::GetPrivateUserInfo,
4167    response: Response<proto::GetPrivateUserInfo>,
4168    session: Session,
4169) -> Result<()> {
4170    let db = session.db().await;
4171
4172    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4173    let user = db
4174        .get_user_by_id(session.user_id())
4175        .await?
4176        .context("user not found")?;
4177    let flags = db.get_user_flags(session.user_id()).await?;
4178
4179    response.send(proto::GetPrivateUserInfoResponse {
4180        metrics_id,
4181        staff: user.admin,
4182        flags,
4183        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4184    })?;
4185    Ok(())
4186}
4187
4188/// Accept the terms of service (tos) on behalf of the current user
4189async fn accept_terms_of_service(
4190    _request: proto::AcceptTermsOfService,
4191    response: Response<proto::AcceptTermsOfService>,
4192    session: Session,
4193) -> Result<()> {
4194    let db = session.db().await;
4195
4196    let accepted_tos_at = Utc::now();
4197    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4198        .await?;
4199
4200    response.send(proto::AcceptTermsOfServiceResponse {
4201        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4202    })?;
4203
4204    // When the user accepts the terms of service, we want to refresh their LLM
4205    // token to grant access.
4206    session
4207        .peer
4208        .send(session.connection_id, proto::RefreshLlmToken {})?;
4209
4210    Ok(())
4211}
4212
4213async fn get_llm_api_token(
4214    _request: proto::GetLlmToken,
4215    response: Response<proto::GetLlmToken>,
4216    session: Session,
4217) -> Result<()> {
4218    let db = session.db().await;
4219
4220    let flags = db.get_user_flags(session.user_id()).await?;
4221
4222    let user_id = session.user_id();
4223    let user = db
4224        .get_user_by_id(user_id)
4225        .await?
4226        .with_context(|| format!("user {user_id} not found"))?;
4227
4228    if user.accepted_tos_at.is_none() {
4229        Err(anyhow!("terms of service not accepted"))?
4230    }
4231
4232    let stripe_client = session
4233        .app_state
4234        .stripe_client
4235        .as_ref()
4236        .context("failed to retrieve Stripe client")?;
4237
4238    let stripe_billing = session
4239        .app_state
4240        .stripe_billing
4241        .as_ref()
4242        .context("failed to retrieve Stripe billing object")?;
4243
4244    let billing_customer = if let Some(billing_customer) =
4245        db.get_billing_customer_by_user_id(user.id).await?
4246    {
4247        billing_customer
4248    } else {
4249        let customer_id = stripe_billing
4250            .find_or_create_customer_by_email(user.email_address.as_deref())
4251            .await?;
4252
4253        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
4254            .await?
4255            .context("billing customer not found")?
4256    };
4257
4258    let billing_subscription =
4259        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
4260            billing_subscription
4261        } else {
4262            let stripe_customer_id =
4263                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
4264
4265            let stripe_subscription = stripe_billing
4266                .subscribe_to_zed_free(stripe_customer_id)
4267                .await?;
4268
4269            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
4270                billing_customer_id: billing_customer.id,
4271                kind: Some(SubscriptionKind::ZedFree),
4272                stripe_subscription_id: stripe_subscription.id.to_string(),
4273                stripe_subscription_status: stripe_subscription.status.into(),
4274                stripe_cancellation_reason: None,
4275                stripe_current_period_start: Some(stripe_subscription.current_period_start),
4276                stripe_current_period_end: Some(stripe_subscription.current_period_end),
4277            })
4278            .await?
4279        };
4280
4281    let billing_preferences = db.get_billing_preferences(user.id).await?;
4282
4283    let token = LlmTokenClaims::create(
4284        &user,
4285        session.is_staff(),
4286        billing_customer,
4287        billing_preferences,
4288        &flags,
4289        billing_subscription,
4290        session.system_id.clone(),
4291        &session.app_state.config,
4292    )?;
4293    response.send(proto::GetLlmTokenResponse { token })?;
4294    Ok(())
4295}
4296
4297fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4298    let message = match message {
4299        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4300        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4301        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4302        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4303        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4304            code: frame.code.into(),
4305            reason: frame.reason.as_str().to_owned().into(),
4306        })),
4307        // We should never receive a frame while reading the message, according
4308        // to the `tungstenite` maintainers:
4309        //
4310        // > It cannot occur when you read messages from the WebSocket, but it
4311        // > can be used when you want to send the raw frames (e.g. you want to
4312        // > send the frames to the WebSocket without composing the full message first).
4313        // >
4314        // > — https://github.com/snapview/tungstenite-rs/issues/268
4315        TungsteniteMessage::Frame(_) => {
4316            bail!("received an unexpected frame while reading the message")
4317        }
4318    };
4319
4320    Ok(message)
4321}
4322
4323fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4324    match message {
4325        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4326        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4327        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4328        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4329        AxumMessage::Close(frame) => {
4330            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4331                code: frame.code.into(),
4332                reason: frame.reason.as_ref().into(),
4333            }))
4334        }
4335    }
4336}
4337
4338fn notify_membership_updated(
4339    connection_pool: &mut ConnectionPool,
4340    result: MembershipUpdated,
4341    user_id: UserId,
4342    peer: &Peer,
4343) {
4344    for membership in &result.new_channels.channel_memberships {
4345        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4346    }
4347    for channel_id in &result.removed_channels {
4348        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4349    }
4350
4351    let user_channels_update = proto::UpdateUserChannels {
4352        channel_memberships: result
4353            .new_channels
4354            .channel_memberships
4355            .iter()
4356            .map(|cm| proto::ChannelMembership {
4357                channel_id: cm.channel_id.to_proto(),
4358                role: cm.role.into(),
4359            })
4360            .collect(),
4361        ..Default::default()
4362    };
4363
4364    let mut update = build_channels_update(result.new_channels);
4365    update.delete_channels = result
4366        .removed_channels
4367        .into_iter()
4368        .map(|id| id.to_proto())
4369        .collect();
4370    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4371
4372    for connection_id in connection_pool.user_connection_ids(user_id) {
4373        peer.send(connection_id, user_channels_update.clone())
4374            .trace_err();
4375        peer.send(connection_id, update.clone()).trace_err();
4376    }
4377}
4378
4379fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4380    proto::UpdateUserChannels {
4381        channel_memberships: channels
4382            .channel_memberships
4383            .iter()
4384            .map(|m| proto::ChannelMembership {
4385                channel_id: m.channel_id.to_proto(),
4386                role: m.role.into(),
4387            })
4388            .collect(),
4389        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4390        observed_channel_message_id: channels.observed_channel_messages.clone(),
4391    }
4392}
4393
4394fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4395    let mut update = proto::UpdateChannels::default();
4396
4397    for channel in channels.channels {
4398        update.channels.push(channel.to_proto());
4399    }
4400
4401    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4402    update.latest_channel_message_ids = channels.latest_channel_messages;
4403
4404    for (channel_id, participants) in channels.channel_participants {
4405        update
4406            .channel_participants
4407            .push(proto::ChannelParticipants {
4408                channel_id: channel_id.to_proto(),
4409                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4410            });
4411    }
4412
4413    for channel in channels.invited_channels {
4414        update.channel_invitations.push(channel.to_proto());
4415    }
4416
4417    update
4418}
4419
4420fn build_initial_contacts_update(
4421    contacts: Vec<db::Contact>,
4422    pool: &ConnectionPool,
4423) -> proto::UpdateContacts {
4424    let mut update = proto::UpdateContacts::default();
4425
4426    for contact in contacts {
4427        match contact {
4428            db::Contact::Accepted { user_id, busy } => {
4429                update.contacts.push(contact_for_user(user_id, busy, pool));
4430            }
4431            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4432            db::Contact::Incoming { user_id } => {
4433                update
4434                    .incoming_requests
4435                    .push(proto::IncomingContactRequest {
4436                        requester_id: user_id.to_proto(),
4437                    })
4438            }
4439        }
4440    }
4441
4442    update
4443}
4444
4445fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4446    proto::Contact {
4447        user_id: user_id.to_proto(),
4448        online: pool.is_user_online(user_id),
4449        busy,
4450    }
4451}
4452
4453fn room_updated(room: &proto::Room, peer: &Peer) {
4454    broadcast(
4455        None,
4456        room.participants
4457            .iter()
4458            .filter_map(|participant| Some(participant.peer_id?.into())),
4459        |peer_id| {
4460            peer.send(
4461                peer_id,
4462                proto::RoomUpdated {
4463                    room: Some(room.clone()),
4464                },
4465            )
4466        },
4467    );
4468}
4469
4470fn channel_updated(
4471    channel: &db::channel::Model,
4472    room: &proto::Room,
4473    peer: &Peer,
4474    pool: &ConnectionPool,
4475) {
4476    let participants = room
4477        .participants
4478        .iter()
4479        .map(|p| p.user_id)
4480        .collect::<Vec<_>>();
4481
4482    broadcast(
4483        None,
4484        pool.channel_connection_ids(channel.root_id())
4485            .filter_map(|(channel_id, role)| {
4486                role.can_see_channel(channel.visibility)
4487                    .then_some(channel_id)
4488            }),
4489        |peer_id| {
4490            peer.send(
4491                peer_id,
4492                proto::UpdateChannels {
4493                    channel_participants: vec![proto::ChannelParticipants {
4494                        channel_id: channel.id.to_proto(),
4495                        participant_user_ids: participants.clone(),
4496                    }],
4497                    ..Default::default()
4498                },
4499            )
4500        },
4501    );
4502}
4503
4504async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4505    let db = session.db().await;
4506
4507    let contacts = db.get_contacts(user_id).await?;
4508    let busy = db.is_user_busy(user_id).await?;
4509
4510    let pool = session.connection_pool().await;
4511    let updated_contact = contact_for_user(user_id, busy, &pool);
4512    for contact in contacts {
4513        if let db::Contact::Accepted {
4514            user_id: contact_user_id,
4515            ..
4516        } = contact
4517        {
4518            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4519                session
4520                    .peer
4521                    .send(
4522                        contact_conn_id,
4523                        proto::UpdateContacts {
4524                            contacts: vec![updated_contact.clone()],
4525                            remove_contacts: Default::default(),
4526                            incoming_requests: Default::default(),
4527                            remove_incoming_requests: Default::default(),
4528                            outgoing_requests: Default::default(),
4529                            remove_outgoing_requests: Default::default(),
4530                        },
4531                    )
4532                    .trace_err();
4533            }
4534        }
4535    }
4536    Ok(())
4537}
4538
/// Removes `connection_id` from the room it is in (if any) and propagates the
/// consequences.
///
/// The propagation covers:
/// - notifying connections in any projects the user left;
/// - broadcasting the updated room (and channel participants, when the room
///   belongs to a channel);
/// - sending `CallCanceled` to users whose calls were canceled;
/// - pushing refreshed contact status for this user and those callees;
/// - removing the user from the LiveKit room, and deleting that room when the
///   database reports the room as deleted.
///
/// Returns `Ok(())` immediately if the connection was not in a room.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    // Users whose contact entry must be re-sent to their contacts.
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: every binding is assigned in the `if let`
    // branch below, and the `else` branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // NOTE(review): the `session.db()` guard is a temporary in this `if let`
    // scrutinee; how long it stays held depends on the edition's temporary
    // scope rules — keep that in mind before reordering this block.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out below, so the `room` broadcast by
        // `room_updated` carries a default (empty) `livekit_room`.
        // NOTE(review): confirm that hiding it from clients is intended.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before
    // `update_user_contacts`, which acquires the pool again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // NOTE(review): an error here returns early and skips the LiveKit cleanup
    // below; confirm that is acceptable.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged, not returned.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        // Cloned because the room name is needed again for `delete_room`.
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4612
4613async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4614    let left_channel_buffers = session
4615        .db()
4616        .await
4617        .leave_channel_buffers(session.connection_id)
4618        .await?;
4619
4620    for left_buffer in left_channel_buffers {
4621        channel_buffer_updated(
4622            session.connection_id,
4623            left_buffer.connections,
4624            &proto::UpdateChannelBufferCollaborators {
4625                channel_id: left_buffer.channel_id.to_proto(),
4626                collaborators: left_buffer.collaborators,
4627            },
4628            &session.peer,
4629        );
4630    }
4631
4632    Ok(())
4633}
4634
4635fn project_left(project: &db::LeftProject, session: &Session) {
4636    for connection_id in &project.connection_ids {
4637        if project.should_unshare {
4638            session
4639                .peer
4640                .send(
4641                    *connection_id,
4642                    proto::UnshareProject {
4643                        project_id: project.id.to_proto(),
4644                    },
4645                )
4646                .trace_err();
4647        } else {
4648            session
4649                .peer
4650                .send(
4651                    *connection_id,
4652                    proto::RemoveProjectCollaborator {
4653                        project_id: project.id.to_proto(),
4654                        peer_id: Some(session.connection_id.into()),
4655                    },
4656                )
4657                .trace_err();
4658        }
4659    }
4660}
4661
4662pub trait ResultExt {
4663    type Ok;
4664
4665    fn trace_err(self) -> Option<Self::Ok>;
4666}
4667
4668impl<T, E> ResultExt for Result<T, E>
4669where
4670    E: std::fmt::Debug,
4671{
4672    type Ok = T;
4673
4674    #[track_caller]
4675    fn trace_err(self) -> Option<T> {
4676        match self {
4677            Ok(value) => Some(value),
4678            Err(error) => {
4679                tracing::error!("{:?}", error);
4680                None
4681            }
4682        }
4683    }
4684}