rpc.rs

   1mod connection_pool;
   2
   3use crate::api::billing::find_or_create_billing_customer;
   4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   5use crate::db::billing_subscription::SubscriptionKind;
   6use crate::llm::db::LlmDatabase;
   7use crate::llm::{
   8    AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
   9    MIN_ACCOUNT_AGE_FOR_LLM_USE,
  10};
  11use crate::stripe_client::StripeCustomerId;
  12use crate::{
  13    AppState, Error, Result, auth,
  14    db::{
  15        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
  16        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  17        NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
  18        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  19    },
  20    executor::Executor,
  21};
  22use anyhow::{Context as _, anyhow, bail};
  23use async_tungstenite::tungstenite::{
  24    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  25};
  26use axum::headers::UserAgent;
  27use axum::{
  28    Extension, Router, TypedHeader,
  29    body::Body,
  30    extract::{
  31        ConnectInfo, WebSocketUpgrade,
  32        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  33    },
  34    headers::{Header, HeaderName},
  35    http::StatusCode,
  36    middleware,
  37    response::IntoResponse,
  38    routing::get,
  39};
  40use chrono::Utc;
  41use collections::{HashMap, HashSet};
  42pub use connection_pool::{ConnectionPool, ZedVersion};
  43use core::fmt::{self, Debug, Formatter};
  44use reqwest_client::ReqwestClient;
  45use rpc::proto::split_repository_update;
  46use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  47
  48use futures::{
  49    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  50    stream::FuturesUnordered,
  51};
  52use prometheus::{IntGauge, register_int_gauge};
  53use rpc::{
  54    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  55    proto::{
  56        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  57        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  58    },
  59};
  60use semantic_version::SemanticVersion;
  61use serde::{Serialize, Serializer};
  62use std::{
  63    any::TypeId,
  64    future::Future,
  65    marker::PhantomData,
  66    mem,
  67    net::SocketAddr,
  68    ops::{Deref, DerefMut},
  69    rc::Rc,
  70    sync::{
  71        Arc, OnceLock,
  72        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  73    },
  74    time::{Duration, Instant},
  75};
  76use time::OffsetDateTime;
  77use tokio::sync::{Semaphore, watch};
  78use tower::ServiceBuilder;
  79use tracing::{
  80    Instrument,
  81    field::{self},
  82    info_span, instrument,
  83};
  84
/// A 30-second timeout. NOTE(review): by its name this is the window a client has to
/// reconnect; it is not used in this part of the file, so confirm at the use site.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long `Server::start` sleeps before it cleans up resources left by stale servers.
// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are not used in this part of the file; judging by
// their names they are page sizes and a maximum chat message length. Confirm the unit of
// `MAX_MESSAGE_LEN` (bytes vs. characters) where it is enforced.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Upper bound on simultaneously held `ConnectionGuard`s (see `ConnectionGuard::try_acquire`).
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Number of connection slots currently reserved; incremented by
/// `ConnectionGuard::try_acquire` and decremented when a guard is dropped.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

/// Type-erased handler stored in `Server::handlers`: receives the boxed envelope together
/// with the `Session` and returns a boxed future that completes with `()`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  99
 100pub struct ConnectionGuard;
 101
 102impl ConnectionGuard {
 103    pub fn try_acquire() -> Result<Self, ()> {
 104        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
 105        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 106            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 107            tracing::error!(
 108                "too many concurrent connections: {}",
 109                current_connections + 1
 110            );
 111            return Err(());
 112        }
 113        Ok(ConnectionGuard)
 114    }
 115}
 116
 117impl Drop for ConnectionGuard {
 118    fn drop(&mut self) {
 119        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 120    }
 121}
 122
 123struct Response<R> {
 124    peer: Arc<Peer>,
 125    receipt: Receipt<R>,
 126    responded: Arc<AtomicBool>,
 127}
 128
 129impl<R: RequestMessage> Response<R> {
 130    fn send(self, payload: R::Response) -> Result<()> {
 131        self.responded.store(true, SeqCst);
 132        self.peer.respond(self.receipt, payload)?;
 133        Ok(())
 134    }
 135}
 136
 137#[derive(Clone, Debug)]
 138pub enum Principal {
 139    User(User),
 140    Impersonated { user: User, admin: User },
 141}
 142
 143impl Principal {
 144    fn user(&self) -> &User {
 145        match self {
 146            Principal::User(user) => user,
 147            Principal::Impersonated { user, .. } => user,
 148        }
 149    }
 150
 151    fn update_span(&self, span: &tracing::Span) {
 152        match &self {
 153            Principal::User(user) => {
 154                span.record("user_id", user.id.0);
 155                span.record("login", &user.github_login);
 156            }
 157            Principal::Impersonated { user, admin } => {
 158                span.record("user_id", user.id.0);
 159                span.record("login", &user.github_login);
 160                span.record("impersonator", &admin.github_login);
 161            }
 162        }
 163    }
 164}
 165
/// Per-connection state passed to every message handler.
///
/// The struct is `Clone`; the shared parts (`db`, `peer`, `connection_pool`, `app_state`)
/// are behind `Arc`s.
#[derive(Clone)]
struct Session {
    /// Who this connection is acting as.
    principal: Principal,
    /// Id of the connection this session belongs to.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; use `Session::db` to lock it.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Connection pool behind a `parking_lot` mutex; use `Session::connection_pool` to lock it.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Supermaven admin API client, when one is available.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    // NOTE(review): presumably taken from the `SystemIdHeader` of the connecting client;
    // the code that fills it in is not in this part of the file.
    system_id: Option<String>,
    _executor: Executor,
}
 181
 182impl Session {
 183    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 184        #[cfg(test)]
 185        tokio::task::yield_now().await;
 186        let guard = self.db.lock().await;
 187        #[cfg(test)]
 188        tokio::task::yield_now().await;
 189        guard
 190    }
 191
 192    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 193        #[cfg(test)]
 194        tokio::task::yield_now().await;
 195        let guard = self.connection_pool.lock();
 196        ConnectionPoolGuard {
 197            guard,
 198            _not_send: PhantomData,
 199        }
 200    }
 201
 202    fn is_staff(&self) -> bool {
 203        match &self.principal {
 204            Principal::User(user) => user.admin,
 205            Principal::Impersonated { .. } => true,
 206        }
 207    }
 208
 209    fn user_id(&self) -> UserId {
 210        match &self.principal {
 211            Principal::User(user) => user.id,
 212            Principal::Impersonated { user, .. } => user.id,
 213        }
 214    }
 215
 216    pub fn email(&self) -> Option<String> {
 217        match &self.principal {
 218            Principal::User(user) => user.email_address.clone(),
 219            Principal::Impersonated { user, .. } => user.email_address.clone(),
 220        }
 221    }
 222}
 223
 224impl Debug for Session {
 225    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 226        let mut result = f.debug_struct("Session");
 227        match &self.principal {
 228            Principal::User(user) => {
 229                result.field("user", &user.github_login);
 230            }
 231            Principal::Impersonated { user, admin } => {
 232                result.field("user", &user.github_login);
 233                result.field("impersonator", &admin.github_login);
 234            }
 235        }
 236        result.field("connection_id", &self.connection_id).finish()
 237    }
 238}
 239
 240struct DbHandle(Arc<Database>);
 241
 242impl Deref for DbHandle {
 243    type Target = Database;
 244
 245    fn deref(&self) -> &Self::Target {
 246        self.0.as_ref()
 247    }
 248}
 249
/// The RPC server: holds the peer, the connection pool, and the table of message handlers.
pub struct Server {
    /// This server's id; behind a mutex because the test-only `reset` replaces it via `&self`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// One handler per message type, keyed by the message's `TypeId`.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Watch channel set to `true` by `teardown` and back to `false` by the test-only `reset`.
    teardown: watch::Sender<bool>,
}
 258
/// A held lock on the `ConnectionPool`, as returned by `Session::connection_pool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is not `Send`, so this marker keeps the whole guard from being `Send`.
    _not_send: PhantomData<Rc<()>>,
}
 263
/// Serializable view of the server: its peer plus the locked connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through `serialize_deref`, i.e. the value the guard dereferences to is
    // written rather than the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 270
 271pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 272where
 273    S: Serializer,
 274    T: Deref<Target = U>,
 275    U: Serialize,
 276{
 277    Serialize::serialize(value.deref(), serializer)
 278}
 279
 280impl Server {
    /// Creates a server with the given id and registers a handler for every supported
    /// message type.
    ///
    /// The peer is created from `id.0 as u32`, the connection pool and handler table start
    /// out as their defaults, and the teardown channel starts at `false` (only the sender
    /// half of the watch channel is kept).
    ///
    /// # Panics
    ///
    /// `add_handler` panics if a message type is registered twice, so a duplicate entry in
    /// the chain below that reaches it aborts construction.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        // Each call below registers exactly one message type. Handlers are stored in a map
        // keyed by message type, so the order of registration carries no meaning.
        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests forwarded via the read-only or the mutating forwarder.
            // NOTE(review): the forwarders are defined outside this part of the file;
            // presumably the choice decides which capability the requester needs.
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_find_search_candidates_request)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LanguageServerIdForName>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_request_handler(get_llm_api_token)
            .add_request_handler(accept_terms_of_service)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            // NOTE(review): `GitReset` and `GitCheckoutFiles` go through the read-only
            // forwarder although their names suggest they modify the repository, while
            // e.g. `GitDiff` below uses the mutating one — confirm this is intended.
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context);

        Arc::new(server)
    }
 459
    /// Schedules the cleanup of resources left behind by stale servers.
    ///
    /// The work runs in a detached task: it sleeps for `CLEANUP_TIMEOUT`, then clears stale
    /// channel chat participants, channel buffer collaborators and room participants,
    /// notifies the affected connections, deletes LiveKit rooms that ended up empty, and
    /// finally clears old worktree entries and stale server records.
    ///
    /// Every step is best-effort: failures go through `trace_err()` and do not stop the
    /// remaining steps. This function itself returns `Ok(())` right after spawning the task.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and send the refreshed
                    // collaborator list to every connection returned by the database.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Collected while the refreshed room is in scope and used by the
                        // notification steps below; they keep these defaults if the
                        // refresh fails.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // The pool lock lives only inside this block, so it is released
                        // before the next `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to all
                        // connections of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                // Both queries are awaited before the lock is taken.
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room when the refreshed room has no
                        // participants left.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Same call as at the top of the task, repeated after the room cleanup.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 630
 631    pub fn teardown(&self) {
 632        self.peer.teardown();
 633        self.connection_pool.lock().reset();
 634        let _ = self.teardown.send(true);
 635    }
 636
 637    #[cfg(test)]
 638    pub fn reset(&self, id: ServerId) {
 639        self.teardown();
 640        *self.id.lock() = id;
 641        self.peer.reset(id.0 as u32);
 642        let _ = self.teardown.send(false);
 643    }
 644
 645    #[cfg(test)]
 646    pub fn id(&self) -> ServerId {
 647        *self.id.lock()
 648    }
 649
 650    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 651    where
 652        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 653        Fut: 'static + Send + Future<Output = Result<()>>,
 654        M: EnvelopedMessage,
 655    {
 656        let prev_handler = self.handlers.insert(
 657            TypeId::of::<M>(),
 658            Box::new(move |envelope, session| {
 659                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 660                let received_at = envelope.received_at;
 661                tracing::info!("message received");
 662                let start_time = Instant::now();
 663                let future = (handler)(*envelope, session);
 664                async move {
 665                    let result = future.await;
 666                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 667                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 668                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 669                    let payload_type = M::NAME;
 670
 671                    match result {
 672                        Err(error) => {
 673                            tracing::error!(
 674                                ?error,
 675                                total_duration_ms,
 676                                processing_duration_ms,
 677                                queue_duration_ms,
 678                                payload_type,
 679                                "error handling message"
 680                            )
 681                        }
 682                        Ok(()) => tracing::info!(
 683                            total_duration_ms,
 684                            processing_duration_ms,
 685                            queue_duration_ms,
 686                            "finished handling message"
 687                        ),
 688                    }
 689                }
 690                .boxed()
 691            }),
 692        );
 693        if prev_handler.is_some() {
 694            panic!("registered a handler for the same message twice");
 695        }
 696        self
 697    }
 698
 699    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 700    where
 701        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 702        Fut: 'static + Send + Future<Output = Result<()>>,
 703        M: EnvelopedMessage,
 704    {
 705        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 706        self
 707    }
 708
 709    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 710    where
 711        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 712        Fut: Send + Future<Output = Result<()>>,
 713        M: RequestMessage,
 714    {
 715        let handler = Arc::new(handler);
 716        self.add_handler(move |envelope, session| {
 717            let receipt = envelope.receipt();
 718            let handler = handler.clone();
 719            async move {
 720                let peer = session.peer.clone();
 721                let responded = Arc::new(AtomicBool::default());
 722                let response = Response {
 723                    peer: peer.clone(),
 724                    responded: responded.clone(),
 725                    receipt,
 726                };
 727                match (handler)(envelope.payload, response, session).await {
 728                    Ok(()) => {
 729                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 730                            Ok(())
 731                        } else {
 732                            Err(anyhow!("handler did not send a response"))?
 733                        }
 734                    }
 735                    Err(error) => {
 736                        let proto_err = match &error {
 737                            Error::Internal(err) => err.to_proto(),
 738                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 739                        };
 740                        peer.respond_with_error(receipt, proto_err)?;
 741                        Err(error)
 742                    }
 743                }
 744            }
 745        })
 746    }
 747
    /// Builds the future that drives a single client connection from start to finish.
    ///
    /// The returned future registers `connection` with the peer, builds a `Session`,
    /// sends the initial client update, and then dispatches incoming messages to the
    /// registered handlers until the I/O future finishes, the incoming stream ends,
    /// or the server's teardown channel changes. Except on teardown, it finishes by
    /// calling `connection_lost` to sign the session out.
    ///
    /// * `address` - recorded on the tracing spans.
    /// * `principal` - used to populate span fields and stored in the session.
    /// * `zed_version` - forwarded to `send_initial_client_update`.
    /// * `user_agent`, `geoip_country_code` - recorded on the connection span when
    ///   present; the country code is also stored in the session.
    /// * `system_id` - stored in the session.
    /// * `send_connection_id` - forwarded to `send_initial_client_update`.
    /// * `executor` - used for the peer's sleep timer, for spawning background
    ///   handlers, and passed on to `connection_lost`.
    /// * `connection_guard` - dropped as soon as the initial client update succeeds.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields start out empty and are filled in below as the values become known.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribe before the future is created so the receiver is moved into it.
        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections while the teardown flag is set.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }

            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // NOTE(review): this shadows the client's `user_agent` parameter with the
            // server's own user agent string, used only for the outgoing HTTP client.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only present when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The guard is only held until the initial update has been sent.
            // NOTE(review): presumably this frees a slot in the concurrent-connection
            // limit enforced by `ConnectionGuard::try_acquire` — confirm.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps the number of handlers in flight for this connection: a permit is
            // acquired before the next message is read and released when its handler ends.
            let concurrent_handlers = Arc::new(Semaphore::new(512));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in the order written: teardown, I/O, in-flight
                // foreground handlers, then the next incoming message.
                futures::select_biased! {
                    // On teardown the future returns immediately, skipping the sign-out below.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            // Enter the span only while the handler future is being created;
                            // the future itself is instrumented with the span below.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // Release the permit only once the handler has finished.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor; foreground
                                // ones are driven by this loop via `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            // The incoming stream ended.
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Drop any foreground handlers that have not finished before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 904
 905    async fn send_initial_client_update(
 906        &self,
 907        connection_id: ConnectionId,
 908        zed_version: ZedVersion,
 909        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 910        session: &Session,
 911    ) -> Result<()> {
 912        self.peer.send(
 913            connection_id,
 914            proto::Hello {
 915                peer_id: Some(connection_id.into()),
 916            },
 917        )?;
 918        tracing::info!("sent hello message");
 919        if let Some(send_connection_id) = send_connection_id.take() {
 920            let _ = send_connection_id.send(connection_id);
 921        }
 922
 923        match &session.principal {
 924            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 925                if !user.connected_once {
 926                    self.peer.send(connection_id, proto::ShowContacts {})?;
 927                    self.app_state
 928                        .db
 929                        .set_user_connected_once(user.id, true)
 930                        .await?;
 931                }
 932
 933                update_user_plan(session).await?;
 934
 935                let contacts = self.app_state.db.get_contacts(user.id).await?;
 936
 937                {
 938                    let mut pool = self.connection_pool.lock();
 939                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 940                    self.peer.send(
 941                        connection_id,
 942                        build_initial_contacts_update(contacts, &pool),
 943                    )?;
 944                }
 945
 946                if should_auto_subscribe_to_channels(zed_version) {
 947                    subscribe_user_to_channels(user.id, session).await?;
 948                }
 949
 950                if let Some(incoming_call) =
 951                    self.app_state.db.incoming_call_for_user(user.id).await?
 952                {
 953                    self.peer.send(connection_id, incoming_call)?;
 954                }
 955
 956                update_user_contacts(user.id, session).await?;
 957            }
 958        }
 959
 960        Ok(())
 961    }
 962
 963    pub async fn invite_code_redeemed(
 964        self: &Arc<Self>,
 965        inviter_id: UserId,
 966        invitee_id: UserId,
 967    ) -> Result<()> {
 968        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 969            if let Some(code) = &user.invite_code {
 970                let pool = self.connection_pool.lock();
 971                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 972                for connection_id in pool.user_connection_ids(inviter_id) {
 973                    self.peer.send(
 974                        connection_id,
 975                        proto::UpdateContacts {
 976                            contacts: vec![invitee_contact.clone()],
 977                            ..Default::default()
 978                        },
 979                    )?;
 980                    self.peer.send(
 981                        connection_id,
 982                        proto::UpdateInviteInfo {
 983                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 984                            count: user.invite_count as u32,
 985                        },
 986                    )?;
 987                }
 988            }
 989        }
 990        Ok(())
 991    }
 992
 993    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 994        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 995            if let Some(invite_code) = &user.invite_code {
 996                let pool = self.connection_pool.lock();
 997                for connection_id in pool.user_connection_ids(user_id) {
 998                    self.peer.send(
 999                        connection_id,
1000                        proto::UpdateInviteInfo {
1001                            url: format!(
1002                                "{}{}",
1003                                self.app_state.config.invite_link_prefix, invite_code
1004                            ),
1005                            count: user.invite_count as u32,
1006                        },
1007                    )?;
1008                }
1009            }
1010        }
1011        Ok(())
1012    }
1013
1014    pub async fn update_plan_for_user(
1015        self: &Arc<Self>,
1016        user_id: UserId,
1017        update_user_plan: proto::UpdateUserPlan,
1018    ) -> Result<()> {
1019        let pool = self.connection_pool.lock();
1020        for connection_id in pool.user_connection_ids(user_id) {
1021            self.peer
1022                .send(connection_id, update_user_plan.clone())
1023                .trace_err();
1024        }
1025
1026        Ok(())
1027    }
1028
1029    /// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan`
1030    /// message on the Collab server.
1031    ///
1032    /// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint.
1033    pub async fn update_plan_for_user_legacy(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1034        let user = self
1035            .app_state
1036            .db
1037            .get_user_by_id(user_id)
1038            .await?
1039            .context("user not found")?;
1040
1041        let update_user_plan = make_update_user_plan_message(
1042            &user,
1043            user.admin,
1044            &self.app_state.db,
1045            self.app_state.llm_db.clone(),
1046        )
1047        .await?;
1048
1049        self.update_plan_for_user(user_id, update_user_plan).await
1050    }
1051
1052    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
1053        let pool = self.connection_pool.lock();
1054        for connection_id in pool.user_connection_ids(user_id) {
1055            self.peer
1056                .send(connection_id, proto::RefreshLlmToken {})
1057                .trace_err();
1058        }
1059    }
1060
1061    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1062        ServerSnapshot {
1063            connection_pool: ConnectionPoolGuard {
1064                guard: self.connection_pool.lock(),
1065                _not_send: PhantomData,
1066            },
1067            peer: &self.peer,
1068        }
1069    }
1070}
1071
1072impl Deref for ConnectionPoolGuard<'_> {
1073    type Target = ConnectionPool;
1074
1075    fn deref(&self) -> &Self::Target {
1076        &self.guard
1077    }
1078}
1079
1080impl DerefMut for ConnectionPoolGuard<'_> {
1081    fn deref_mut(&mut self) -> &mut Self::Target {
1082        &mut self.guard
1083    }
1084}
1085
1086impl Drop for ConnectionPoolGuard<'_> {
1087    fn drop(&mut self) {
1088        #[cfg(test)]
1089        self.check_invariants();
1090    }
1091}
1092
1093fn broadcast<F>(
1094    sender_id: Option<ConnectionId>,
1095    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1096    mut f: F,
1097) where
1098    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1099{
1100    for receiver_id in receiver_ids {
1101        if Some(receiver_id) != sender_id {
1102            if let Err(error) = f(receiver_id) {
1103                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1104            }
1105        }
1106    }
1107}
1108
1109pub struct ProtocolVersion(u32);
1110
1111impl Header for ProtocolVersion {
1112    fn name() -> &'static HeaderName {
1113        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1114        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1115    }
1116
1117    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1118    where
1119        Self: Sized,
1120        I: Iterator<Item = &'i axum::http::HeaderValue>,
1121    {
1122        let version = values
1123            .next()
1124            .ok_or_else(axum::headers::Error::invalid)?
1125            .to_str()
1126            .map_err(|_| axum::headers::Error::invalid())?
1127            .parse()
1128            .map_err(|_| axum::headers::Error::invalid())?;
1129        Ok(Self(version))
1130    }
1131
1132    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1133        values.extend([self.0.to_string().parse().unwrap()]);
1134    }
1135}
1136
1137pub struct AppVersionHeader(SemanticVersion);
1138impl Header for AppVersionHeader {
1139    fn name() -> &'static HeaderName {
1140        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1141        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1142    }
1143
1144    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1145    where
1146        Self: Sized,
1147        I: Iterator<Item = &'i axum::http::HeaderValue>,
1148    {
1149        let version = values
1150            .next()
1151            .ok_or_else(axum::headers::Error::invalid)?
1152            .to_str()
1153            .map_err(|_| axum::headers::Error::invalid())?
1154            .parse()
1155            .map_err(|_| axum::headers::Error::invalid())?;
1156        Ok(Self(version))
1157    }
1158
1159    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1160        values.extend([self.0.to_string().parse().unwrap()]);
1161    }
1162}
1163
1164pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1165    Router::new()
1166        .route("/rpc", get(handle_websocket_request))
1167        .layer(
1168            ServiceBuilder::new()
1169                .layer(Extension(server.app_state.clone()))
1170                .layer(middleware::from_fn(auth::validate_header)),
1171        )
1172        .route("/metrics", get(handle_metrics))
1173        .layer(Extension(server))
1174}
1175
1176pub async fn handle_websocket_request(
1177    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1178    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1179    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1180    Extension(server): Extension<Arc<Server>>,
1181    Extension(principal): Extension<Principal>,
1182    user_agent: Option<TypedHeader<UserAgent>>,
1183    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1184    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1185    ws: WebSocketUpgrade,
1186) -> axum::response::Response {
1187    if protocol_version != rpc::PROTOCOL_VERSION {
1188        return (
1189            StatusCode::UPGRADE_REQUIRED,
1190            "client must be upgraded".to_string(),
1191        )
1192            .into_response();
1193    }
1194
1195    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1196        return (
1197            StatusCode::UPGRADE_REQUIRED,
1198            "no version header found".to_string(),
1199        )
1200            .into_response();
1201    };
1202
1203    if !version.can_collaborate() {
1204        return (
1205            StatusCode::UPGRADE_REQUIRED,
1206            "client must be upgraded".to_string(),
1207        )
1208            .into_response();
1209    }
1210
1211    let socket_address = socket_address.to_string();
1212
1213    // Acquire connection guard before WebSocket upgrade
1214    let connection_guard = match ConnectionGuard::try_acquire() {
1215        Ok(guard) => guard,
1216        Err(()) => {
1217            return (
1218                StatusCode::SERVICE_UNAVAILABLE,
1219                "Too many concurrent connections",
1220            )
1221                .into_response();
1222        }
1223    };
1224
1225    ws.on_upgrade(move |socket| {
1226        let socket = socket
1227            .map_ok(to_tungstenite_message)
1228            .err_into()
1229            .with(|message| async move { to_axum_message(message) });
1230        let connection = Connection::new(Box::pin(socket));
1231        async move {
1232            server
1233                .handle_connection(
1234                    connection,
1235                    socket_address,
1236                    principal,
1237                    version,
1238                    user_agent.map(|header| header.to_string()),
1239                    country_code_header.map(|header| header.to_string()),
1240                    system_id_header.map(|header| header.to_string()),
1241                    None,
1242                    Executor::Production,
1243                    Some(connection_guard),
1244                )
1245                .await;
1246        }
1247    })
1248}
1249
1250pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1251    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1252    let connections_metric = CONNECTIONS_METRIC
1253        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1254
1255    let connections = server
1256        .connection_pool
1257        .lock()
1258        .connections()
1259        .filter(|connection| !connection.admin)
1260        .count();
1261    connections_metric.set(connections as _);
1262
1263    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1264    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1265        register_int_gauge!(
1266            "shared_projects",
1267            "number of open projects with one or more guests"
1268        )
1269        .unwrap()
1270    });
1271
1272    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1273    shared_projects_metric.set(shared_projects as _);
1274
1275    let encoder = prometheus::TextEncoder::new();
1276    let metric_families = prometheus::gather();
1277    let encoded_metrics = encoder
1278        .encode_to_string(&metric_families)
1279        .map_err(|err| anyhow!("{err}"))?;
1280    Ok(encoded_metrics)
1281}
1282
1283#[instrument(err, skip(executor))]
1284async fn connection_lost(
1285    session: Session,
1286    mut teardown: watch::Receiver<bool>,
1287    executor: Executor,
1288) -> Result<()> {
1289    session.peer.disconnect(session.connection_id);
1290    session
1291        .connection_pool()
1292        .await
1293        .remove_connection(session.connection_id)?;
1294
1295    session
1296        .db()
1297        .await
1298        .connection_lost(session.connection_id)
1299        .await
1300        .trace_err();
1301
1302    futures::select_biased! {
1303        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1304
1305            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1306            leave_room_for_session(&session, session.connection_id).await.trace_err();
1307            leave_channel_buffers_for_session(&session)
1308                .await
1309                .trace_err();
1310
1311            if !session
1312                .connection_pool()
1313                .await
1314                .is_user_online(session.user_id())
1315            {
1316                let db = session.db().await;
1317                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1318                    room_updated(&room, &session.peer);
1319                }
1320            }
1321
1322            update_user_contacts(session.user_id(), &session).await?;
1323        },
1324        _ = teardown.changed().fuse() => {}
1325    }
1326
1327    Ok(())
1328}
1329
1330/// Acknowledges a ping from a client, used to keep the connection alive.
1331async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1332    response.send(proto::Ack {})?;
1333    Ok(())
1334}
1335
1336/// Creates a new room for calling (outside of channels)
1337async fn create_room(
1338    _request: proto::CreateRoom,
1339    response: Response<proto::CreateRoom>,
1340    session: Session,
1341) -> Result<()> {
1342    let livekit_room = nanoid::nanoid!(30);
1343
1344    let live_kit_connection_info = util::maybe!(async {
1345        let live_kit = session.app_state.livekit_client.as_ref();
1346        let live_kit = live_kit?;
1347        let user_id = session.user_id().to_string();
1348
1349        let token = live_kit
1350            .room_token(&livekit_room, &user_id.to_string())
1351            .trace_err()?;
1352
1353        Some(proto::LiveKitConnectionInfo {
1354            server_url: live_kit.url().into(),
1355            token,
1356            can_publish: true,
1357        })
1358    })
1359    .await;
1360
1361    let room = session
1362        .db()
1363        .await
1364        .create_room(session.user_id(), session.connection_id, &livekit_room)
1365        .await?;
1366
1367    response.send(proto::CreateRoomResponse {
1368        room: Some(room.clone()),
1369        live_kit_connection_info,
1370    })?;
1371
1372    update_user_contacts(session.user_id(), &session).await?;
1373    Ok(())
1374}
1375
1376/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1377async fn join_room(
1378    request: proto::JoinRoom,
1379    response: Response<proto::JoinRoom>,
1380    session: Session,
1381) -> Result<()> {
1382    let room_id = RoomId::from_proto(request.id);
1383
1384    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1385
1386    if let Some(channel_id) = channel_id {
1387        return join_channel_internal(channel_id, Box::new(response), session).await;
1388    }
1389
1390    let joined_room = {
1391        let room = session
1392            .db()
1393            .await
1394            .join_room(room_id, session.user_id(), session.connection_id)
1395            .await?;
1396        room_updated(&room.room, &session.peer);
1397        room.into_inner()
1398    };
1399
1400    for connection_id in session
1401        .connection_pool()
1402        .await
1403        .user_connection_ids(session.user_id())
1404    {
1405        session
1406            .peer
1407            .send(
1408                connection_id,
1409                proto::CallCanceled {
1410                    room_id: room_id.to_proto(),
1411                },
1412            )
1413            .trace_err();
1414    }
1415
1416    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1417    {
1418        live_kit
1419            .room_token(
1420                &joined_room.room.livekit_room,
1421                &session.user_id().to_string(),
1422            )
1423            .trace_err()
1424            .map(|token| proto::LiveKitConnectionInfo {
1425                server_url: live_kit.url().into(),
1426                token,
1427                can_publish: true,
1428            })
1429    } else {
1430        None
1431    };
1432
1433    response.send(proto::JoinRoomResponse {
1434        room: Some(joined_room.room),
1435        channel_id: None,
1436        live_kit_connection_info,
1437    })?;
1438
1439    update_user_contacts(session.user_id(), &session).await?;
1440    Ok(())
1441}
1442
1443/// Rejoin room is used to reconnect to a room after connection errors.
1444async fn rejoin_room(
1445    request: proto::RejoinRoom,
1446    response: Response<proto::RejoinRoom>,
1447    session: Session,
1448) -> Result<()> {
1449    let room;
1450    let channel;
1451    {
1452        let mut rejoined_room = session
1453            .db()
1454            .await
1455            .rejoin_room(request, session.user_id(), session.connection_id)
1456            .await?;
1457
1458        response.send(proto::RejoinRoomResponse {
1459            room: Some(rejoined_room.room.clone()),
1460            reshared_projects: rejoined_room
1461                .reshared_projects
1462                .iter()
1463                .map(|project| proto::ResharedProject {
1464                    id: project.id.to_proto(),
1465                    collaborators: project
1466                        .collaborators
1467                        .iter()
1468                        .map(|collaborator| collaborator.to_proto())
1469                        .collect(),
1470                })
1471                .collect(),
1472            rejoined_projects: rejoined_room
1473                .rejoined_projects
1474                .iter()
1475                .map(|rejoined_project| rejoined_project.to_proto())
1476                .collect(),
1477        })?;
1478        room_updated(&rejoined_room.room, &session.peer);
1479
1480        for project in &rejoined_room.reshared_projects {
1481            for collaborator in &project.collaborators {
1482                session
1483                    .peer
1484                    .send(
1485                        collaborator.connection_id,
1486                        proto::UpdateProjectCollaborator {
1487                            project_id: project.id.to_proto(),
1488                            old_peer_id: Some(project.old_connection_id.into()),
1489                            new_peer_id: Some(session.connection_id.into()),
1490                        },
1491                    )
1492                    .trace_err();
1493            }
1494
1495            broadcast(
1496                Some(session.connection_id),
1497                project
1498                    .collaborators
1499                    .iter()
1500                    .map(|collaborator| collaborator.connection_id),
1501                |connection_id| {
1502                    session.peer.forward_send(
1503                        session.connection_id,
1504                        connection_id,
1505                        proto::UpdateProject {
1506                            project_id: project.id.to_proto(),
1507                            worktrees: project.worktrees.clone(),
1508                        },
1509                    )
1510                },
1511            );
1512        }
1513
1514        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1515
1516        let rejoined_room = rejoined_room.into_inner();
1517
1518        room = rejoined_room.room;
1519        channel = rejoined_room.channel;
1520    }
1521
1522    if let Some(channel) = channel {
1523        channel_updated(
1524            &channel,
1525            &room,
1526            &session.peer,
1527            &*session.connection_pool().await,
1528        );
1529    }
1530
1531    update_user_contacts(session.user_id(), &session).await?;
1532    Ok(())
1533}
1534
1535fn notify_rejoined_projects(
1536    rejoined_projects: &mut Vec<RejoinedProject>,
1537    session: &Session,
1538) -> Result<()> {
1539    for project in rejoined_projects.iter() {
1540        for collaborator in &project.collaborators {
1541            session
1542                .peer
1543                .send(
1544                    collaborator.connection_id,
1545                    proto::UpdateProjectCollaborator {
1546                        project_id: project.id.to_proto(),
1547                        old_peer_id: Some(project.old_connection_id.into()),
1548                        new_peer_id: Some(session.connection_id.into()),
1549                    },
1550                )
1551                .trace_err();
1552        }
1553    }
1554
1555    for project in rejoined_projects {
1556        for worktree in mem::take(&mut project.worktrees) {
1557            // Stream this worktree's entries.
1558            let message = proto::UpdateWorktree {
1559                project_id: project.id.to_proto(),
1560                worktree_id: worktree.id,
1561                abs_path: worktree.abs_path.clone(),
1562                root_name: worktree.root_name,
1563                updated_entries: worktree.updated_entries,
1564                removed_entries: worktree.removed_entries,
1565                scan_id: worktree.scan_id,
1566                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1567                updated_repositories: worktree.updated_repositories,
1568                removed_repositories: worktree.removed_repositories,
1569            };
1570            for update in proto::split_worktree_update(message) {
1571                session.peer.send(session.connection_id, update)?;
1572            }
1573
1574            // Stream this worktree's diagnostics.
1575            for summary in worktree.diagnostic_summaries {
1576                session.peer.send(
1577                    session.connection_id,
1578                    proto::UpdateDiagnosticSummary {
1579                        project_id: project.id.to_proto(),
1580                        worktree_id: worktree.id,
1581                        summary: Some(summary),
1582                    },
1583                )?;
1584            }
1585
1586            for settings_file in worktree.settings_files {
1587                session.peer.send(
1588                    session.connection_id,
1589                    proto::UpdateWorktreeSettings {
1590                        project_id: project.id.to_proto(),
1591                        worktree_id: worktree.id,
1592                        path: settings_file.path,
1593                        content: Some(settings_file.content),
1594                        kind: Some(settings_file.kind.to_proto().into()),
1595                    },
1596                )?;
1597            }
1598        }
1599
1600        for repository in mem::take(&mut project.updated_repositories) {
1601            for update in split_repository_update(repository) {
1602                session.peer.send(session.connection_id, update)?;
1603            }
1604        }
1605
1606        for id in mem::take(&mut project.removed_repositories) {
1607            session.peer.send(
1608                session.connection_id,
1609                proto::RemoveRepository {
1610                    project_id: project.id.to_proto(),
1611                    id,
1612                },
1613            )?;
1614        }
1615    }
1616
1617    Ok(())
1618}
1619
1620/// leave room disconnects from the room.
1621async fn leave_room(
1622    _: proto::LeaveRoom,
1623    response: Response<proto::LeaveRoom>,
1624    session: Session,
1625) -> Result<()> {
1626    leave_room_for_session(&session, session.connection_id).await?;
1627    response.send(proto::Ack {})?;
1628    Ok(())
1629}
1630
1631/// Updates the permissions of someone else in the room.
1632async fn set_room_participant_role(
1633    request: proto::SetRoomParticipantRole,
1634    response: Response<proto::SetRoomParticipantRole>,
1635    session: Session,
1636) -> Result<()> {
1637    let user_id = UserId::from_proto(request.user_id);
1638    let role = ChannelRole::from(request.role());
1639
1640    let (livekit_room, can_publish) = {
1641        let room = session
1642            .db()
1643            .await
1644            .set_room_participant_role(
1645                session.user_id(),
1646                RoomId::from_proto(request.room_id),
1647                user_id,
1648                role,
1649            )
1650            .await?;
1651
1652        let livekit_room = room.livekit_room.clone();
1653        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1654        room_updated(&room, &session.peer);
1655        (livekit_room, can_publish)
1656    };
1657
1658    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1659        live_kit
1660            .update_participant(
1661                livekit_room.clone(),
1662                request.user_id.to_string(),
1663                livekit_api::proto::ParticipantPermission {
1664                    can_subscribe: true,
1665                    can_publish,
1666                    can_publish_data: can_publish,
1667                    hidden: false,
1668                    recorder: false,
1669                },
1670            )
1671            .await
1672            .trace_err();
1673    }
1674
1675    response.send(proto::Ack {})?;
1676    Ok(())
1677}
1678
1679/// Call someone else into the current room
1680async fn call(
1681    request: proto::Call,
1682    response: Response<proto::Call>,
1683    session: Session,
1684) -> Result<()> {
1685    let room_id = RoomId::from_proto(request.room_id);
1686    let calling_user_id = session.user_id();
1687    let calling_connection_id = session.connection_id;
1688    let called_user_id = UserId::from_proto(request.called_user_id);
1689    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1690    if !session
1691        .db()
1692        .await
1693        .has_contact(calling_user_id, called_user_id)
1694        .await?
1695    {
1696        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1697    }
1698
1699    let incoming_call = {
1700        let (room, incoming_call) = &mut *session
1701            .db()
1702            .await
1703            .call(
1704                room_id,
1705                calling_user_id,
1706                calling_connection_id,
1707                called_user_id,
1708                initial_project_id,
1709            )
1710            .await?;
1711        room_updated(room, &session.peer);
1712        mem::take(incoming_call)
1713    };
1714    update_user_contacts(called_user_id, &session).await?;
1715
1716    let mut calls = session
1717        .connection_pool()
1718        .await
1719        .user_connection_ids(called_user_id)
1720        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1721        .collect::<FuturesUnordered<_>>();
1722
1723    while let Some(call_response) = calls.next().await {
1724        match call_response.as_ref() {
1725            Ok(_) => {
1726                response.send(proto::Ack {})?;
1727                return Ok(());
1728            }
1729            Err(_) => {
1730                call_response.trace_err();
1731            }
1732        }
1733    }
1734
1735    {
1736        let room = session
1737            .db()
1738            .await
1739            .call_failed(room_id, called_user_id)
1740            .await?;
1741        room_updated(&room, &session.peer);
1742    }
1743    update_user_contacts(called_user_id, &session).await?;
1744
1745    Err(anyhow!("failed to ring user"))?
1746}
1747
1748/// Cancel an outgoing call.
1749async fn cancel_call(
1750    request: proto::CancelCall,
1751    response: Response<proto::CancelCall>,
1752    session: Session,
1753) -> Result<()> {
1754    let called_user_id = UserId::from_proto(request.called_user_id);
1755    let room_id = RoomId::from_proto(request.room_id);
1756    {
1757        let room = session
1758            .db()
1759            .await
1760            .cancel_call(room_id, session.connection_id, called_user_id)
1761            .await?;
1762        room_updated(&room, &session.peer);
1763    }
1764
1765    for connection_id in session
1766        .connection_pool()
1767        .await
1768        .user_connection_ids(called_user_id)
1769    {
1770        session
1771            .peer
1772            .send(
1773                connection_id,
1774                proto::CallCanceled {
1775                    room_id: room_id.to_proto(),
1776                },
1777            )
1778            .trace_err();
1779    }
1780    response.send(proto::Ack {})?;
1781
1782    update_user_contacts(called_user_id, &session).await?;
1783    Ok(())
1784}
1785
1786/// Decline an incoming call.
1787async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1788    let room_id = RoomId::from_proto(message.room_id);
1789    {
1790        let room = session
1791            .db()
1792            .await
1793            .decline_call(Some(room_id), session.user_id())
1794            .await?
1795            .context("declining call")?;
1796        room_updated(&room, &session.peer);
1797    }
1798
1799    for connection_id in session
1800        .connection_pool()
1801        .await
1802        .user_connection_ids(session.user_id())
1803    {
1804        session
1805            .peer
1806            .send(
1807                connection_id,
1808                proto::CallCanceled {
1809                    room_id: room_id.to_proto(),
1810                },
1811            )
1812            .trace_err();
1813    }
1814    update_user_contacts(session.user_id(), &session).await?;
1815    Ok(())
1816}
1817
1818/// Updates other participants in the room with your current location.
1819async fn update_participant_location(
1820    request: proto::UpdateParticipantLocation,
1821    response: Response<proto::UpdateParticipantLocation>,
1822    session: Session,
1823) -> Result<()> {
1824    let room_id = RoomId::from_proto(request.room_id);
1825    let location = request.location.context("invalid location")?;
1826
1827    let db = session.db().await;
1828    let room = db
1829        .update_room_participant_location(room_id, session.connection_id, location)
1830        .await?;
1831
1832    room_updated(&room, &session.peer);
1833    response.send(proto::Ack {})?;
1834    Ok(())
1835}
1836
1837/// Share a project into the room.
1838async fn share_project(
1839    request: proto::ShareProject,
1840    response: Response<proto::ShareProject>,
1841    session: Session,
1842) -> Result<()> {
1843    let (project_id, room) = &*session
1844        .db()
1845        .await
1846        .share_project(
1847            RoomId::from_proto(request.room_id),
1848            session.connection_id,
1849            &request.worktrees,
1850            request.is_ssh_project,
1851        )
1852        .await?;
1853    response.send(proto::ShareProjectResponse {
1854        project_id: project_id.to_proto(),
1855    })?;
1856    room_updated(room, &session.peer);
1857
1858    Ok(())
1859}
1860
1861/// Unshare a project from the room.
1862async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1863    let project_id = ProjectId::from_proto(message.project_id);
1864    unshare_project_internal(project_id, session.connection_id, &session).await
1865}
1866
1867async fn unshare_project_internal(
1868    project_id: ProjectId,
1869    connection_id: ConnectionId,
1870    session: &Session,
1871) -> Result<()> {
1872    let delete = {
1873        let room_guard = session
1874            .db()
1875            .await
1876            .unshare_project(project_id, connection_id)
1877            .await?;
1878
1879        let (delete, room, guest_connection_ids) = &*room_guard;
1880
1881        let message = proto::UnshareProject {
1882            project_id: project_id.to_proto(),
1883        };
1884
1885        broadcast(
1886            Some(connection_id),
1887            guest_connection_ids.iter().copied(),
1888            |conn_id| session.peer.send(conn_id, message.clone()),
1889        );
1890        if let Some(room) = room {
1891            room_updated(room, &session.peer);
1892        }
1893
1894        *delete
1895    };
1896
1897    if delete {
1898        let db = session.db().await;
1899        db.delete_project(project_id).await?;
1900    }
1901
1902    Ok(())
1903}
1904
1905/// Join someone elses shared project.
1906async fn join_project(
1907    request: proto::JoinProject,
1908    response: Response<proto::JoinProject>,
1909    session: Session,
1910) -> Result<()> {
1911    let project_id = ProjectId::from_proto(request.project_id);
1912
1913    tracing::info!(%project_id, "join project");
1914
1915    let db = session.db().await;
1916    let (project, replica_id) = &mut *db
1917        .join_project(
1918            project_id,
1919            session.connection_id,
1920            session.user_id(),
1921            request.committer_name.clone(),
1922            request.committer_email.clone(),
1923        )
1924        .await?;
1925    drop(db);
1926    tracing::info!(%project_id, "join remote project");
1927    let collaborators = project
1928        .collaborators
1929        .iter()
1930        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1931        .map(|collaborator| collaborator.to_proto())
1932        .collect::<Vec<_>>();
1933    let project_id = project.id;
1934    let guest_user_id = session.user_id();
1935
1936    let worktrees = project
1937        .worktrees
1938        .iter()
1939        .map(|(id, worktree)| proto::WorktreeMetadata {
1940            id: *id,
1941            root_name: worktree.root_name.clone(),
1942            visible: worktree.visible,
1943            abs_path: worktree.abs_path.clone(),
1944        })
1945        .collect::<Vec<_>>();
1946
1947    let add_project_collaborator = proto::AddProjectCollaborator {
1948        project_id: project_id.to_proto(),
1949        collaborator: Some(proto::Collaborator {
1950            peer_id: Some(session.connection_id.into()),
1951            replica_id: replica_id.0 as u32,
1952            user_id: guest_user_id.to_proto(),
1953            is_host: false,
1954            committer_name: request.committer_name.clone(),
1955            committer_email: request.committer_email.clone(),
1956        }),
1957    };
1958
1959    for collaborator in &collaborators {
1960        session
1961            .peer
1962            .send(
1963                collaborator.peer_id.unwrap().into(),
1964                add_project_collaborator.clone(),
1965            )
1966            .trace_err();
1967    }
1968
1969    // First, we send the metadata associated with each worktree.
1970    response.send(proto::JoinProjectResponse {
1971        project_id: project.id.0 as u64,
1972        worktrees: worktrees.clone(),
1973        replica_id: replica_id.0 as u32,
1974        collaborators: collaborators.clone(),
1975        language_servers: project.language_servers.clone(),
1976        role: project.role.into(),
1977    })?;
1978
1979    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1980        // Stream this worktree's entries.
1981        let message = proto::UpdateWorktree {
1982            project_id: project_id.to_proto(),
1983            worktree_id,
1984            abs_path: worktree.abs_path.clone(),
1985            root_name: worktree.root_name,
1986            updated_entries: worktree.entries,
1987            removed_entries: Default::default(),
1988            scan_id: worktree.scan_id,
1989            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1990            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1991            removed_repositories: Default::default(),
1992        };
1993        for update in proto::split_worktree_update(message) {
1994            session.peer.send(session.connection_id, update.clone())?;
1995        }
1996
1997        // Stream this worktree's diagnostics.
1998        for summary in worktree.diagnostic_summaries {
1999            session.peer.send(
2000                session.connection_id,
2001                proto::UpdateDiagnosticSummary {
2002                    project_id: project_id.to_proto(),
2003                    worktree_id: worktree.id,
2004                    summary: Some(summary),
2005                },
2006            )?;
2007        }
2008
2009        for settings_file in worktree.settings_files {
2010            session.peer.send(
2011                session.connection_id,
2012                proto::UpdateWorktreeSettings {
2013                    project_id: project_id.to_proto(),
2014                    worktree_id: worktree.id,
2015                    path: settings_file.path,
2016                    content: Some(settings_file.content),
2017                    kind: Some(settings_file.kind.to_proto() as i32),
2018                },
2019            )?;
2020        }
2021    }
2022
2023    for repository in mem::take(&mut project.repositories) {
2024        for update in split_repository_update(repository) {
2025            session.peer.send(session.connection_id, update)?;
2026        }
2027    }
2028
2029    for language_server in &project.language_servers {
2030        session.peer.send(
2031            session.connection_id,
2032            proto::UpdateLanguageServer {
2033                project_id: project_id.to_proto(),
2034                server_name: Some(language_server.name.clone()),
2035                language_server_id: language_server.id,
2036                variant: Some(
2037                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2038                        proto::LspDiskBasedDiagnosticsUpdated {},
2039                    ),
2040                ),
2041            },
2042        )?;
2043    }
2044
2045    Ok(())
2046}
2047
2048/// Leave someone elses shared project.
2049async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
2050    let sender_id = session.connection_id;
2051    let project_id = ProjectId::from_proto(request.project_id);
2052    let db = session.db().await;
2053
2054    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2055    tracing::info!(
2056        %project_id,
2057        "leave project"
2058    );
2059
2060    project_left(project, &session);
2061    if let Some(room) = room {
2062        room_updated(room, &session.peer);
2063    }
2064
2065    Ok(())
2066}
2067
2068/// Updates other participants with changes to the project
2069async fn update_project(
2070    request: proto::UpdateProject,
2071    response: Response<proto::UpdateProject>,
2072    session: Session,
2073) -> Result<()> {
2074    let project_id = ProjectId::from_proto(request.project_id);
2075    let (room, guest_connection_ids) = &*session
2076        .db()
2077        .await
2078        .update_project(project_id, session.connection_id, &request.worktrees)
2079        .await?;
2080    broadcast(
2081        Some(session.connection_id),
2082        guest_connection_ids.iter().copied(),
2083        |connection_id| {
2084            session
2085                .peer
2086                .forward_send(session.connection_id, connection_id, request.clone())
2087        },
2088    );
2089    if let Some(room) = room {
2090        room_updated(room, &session.peer);
2091    }
2092    response.send(proto::Ack {})?;
2093
2094    Ok(())
2095}
2096
2097/// Updates other participants with changes to the worktree
2098async fn update_worktree(
2099    request: proto::UpdateWorktree,
2100    response: Response<proto::UpdateWorktree>,
2101    session: Session,
2102) -> Result<()> {
2103    let guest_connection_ids = session
2104        .db()
2105        .await
2106        .update_worktree(&request, session.connection_id)
2107        .await?;
2108
2109    broadcast(
2110        Some(session.connection_id),
2111        guest_connection_ids.iter().copied(),
2112        |connection_id| {
2113            session
2114                .peer
2115                .forward_send(session.connection_id, connection_id, request.clone())
2116        },
2117    );
2118    response.send(proto::Ack {})?;
2119    Ok(())
2120}
2121
2122async fn update_repository(
2123    request: proto::UpdateRepository,
2124    response: Response<proto::UpdateRepository>,
2125    session: Session,
2126) -> Result<()> {
2127    let guest_connection_ids = session
2128        .db()
2129        .await
2130        .update_repository(&request, session.connection_id)
2131        .await?;
2132
2133    broadcast(
2134        Some(session.connection_id),
2135        guest_connection_ids.iter().copied(),
2136        |connection_id| {
2137            session
2138                .peer
2139                .forward_send(session.connection_id, connection_id, request.clone())
2140        },
2141    );
2142    response.send(proto::Ack {})?;
2143    Ok(())
2144}
2145
2146async fn remove_repository(
2147    request: proto::RemoveRepository,
2148    response: Response<proto::RemoveRepository>,
2149    session: Session,
2150) -> Result<()> {
2151    let guest_connection_ids = session
2152        .db()
2153        .await
2154        .remove_repository(&request, session.connection_id)
2155        .await?;
2156
2157    broadcast(
2158        Some(session.connection_id),
2159        guest_connection_ids.iter().copied(),
2160        |connection_id| {
2161            session
2162                .peer
2163                .forward_send(session.connection_id, connection_id, request.clone())
2164        },
2165    );
2166    response.send(proto::Ack {})?;
2167    Ok(())
2168}
2169
2170/// Updates other participants with changes to the diagnostics
2171async fn update_diagnostic_summary(
2172    message: proto::UpdateDiagnosticSummary,
2173    session: Session,
2174) -> Result<()> {
2175    let guest_connection_ids = session
2176        .db()
2177        .await
2178        .update_diagnostic_summary(&message, session.connection_id)
2179        .await?;
2180
2181    broadcast(
2182        Some(session.connection_id),
2183        guest_connection_ids.iter().copied(),
2184        |connection_id| {
2185            session
2186                .peer
2187                .forward_send(session.connection_id, connection_id, message.clone())
2188        },
2189    );
2190
2191    Ok(())
2192}
2193
2194/// Updates other participants with changes to the worktree settings
2195async fn update_worktree_settings(
2196    message: proto::UpdateWorktreeSettings,
2197    session: Session,
2198) -> Result<()> {
2199    let guest_connection_ids = session
2200        .db()
2201        .await
2202        .update_worktree_settings(&message, session.connection_id)
2203        .await?;
2204
2205    broadcast(
2206        Some(session.connection_id),
2207        guest_connection_ids.iter().copied(),
2208        |connection_id| {
2209            session
2210                .peer
2211                .forward_send(session.connection_id, connection_id, message.clone())
2212        },
2213    );
2214
2215    Ok(())
2216}
2217
2218/// Notify other participants that a language server has started.
2219async fn start_language_server(
2220    request: proto::StartLanguageServer,
2221    session: Session,
2222) -> Result<()> {
2223    let guest_connection_ids = session
2224        .db()
2225        .await
2226        .start_language_server(&request, session.connection_id)
2227        .await?;
2228
2229    broadcast(
2230        Some(session.connection_id),
2231        guest_connection_ids.iter().copied(),
2232        |connection_id| {
2233            session
2234                .peer
2235                .forward_send(session.connection_id, connection_id, request.clone())
2236        },
2237    );
2238    Ok(())
2239}
2240
2241/// Notify other participants that a language server has changed.
2242async fn update_language_server(
2243    request: proto::UpdateLanguageServer,
2244    session: Session,
2245) -> Result<()> {
2246    let project_id = ProjectId::from_proto(request.project_id);
2247    let project_connection_ids = session
2248        .db()
2249        .await
2250        .project_connection_ids(project_id, session.connection_id, true)
2251        .await?;
2252    broadcast(
2253        Some(session.connection_id),
2254        project_connection_ids.iter().copied(),
2255        |connection_id| {
2256            session
2257                .peer
2258                .forward_send(session.connection_id, connection_id, request.clone())
2259        },
2260    );
2261    Ok(())
2262}
2263
2264/// forward a project request to the host. These requests should be read only
2265/// as guests are allowed to send them.
2266async fn forward_read_only_project_request<T>(
2267    request: T,
2268    response: Response<T>,
2269    session: Session,
2270) -> Result<()>
2271where
2272    T: EntityMessage + RequestMessage,
2273{
2274    let project_id = ProjectId::from_proto(request.remote_entity_id());
2275    let host_connection_id = session
2276        .db()
2277        .await
2278        .host_for_read_only_project_request(project_id, session.connection_id)
2279        .await?;
2280    let payload = session
2281        .peer
2282        .forward_request(session.connection_id, host_connection_id, request)
2283        .await?;
2284    response.send(payload)?;
2285    Ok(())
2286}
2287
2288async fn forward_find_search_candidates_request(
2289    request: proto::FindSearchCandidates,
2290    response: Response<proto::FindSearchCandidates>,
2291    session: Session,
2292) -> Result<()> {
2293    let project_id = ProjectId::from_proto(request.remote_entity_id());
2294    let host_connection_id = session
2295        .db()
2296        .await
2297        .host_for_read_only_project_request(project_id, session.connection_id)
2298        .await?;
2299    let payload = session
2300        .peer
2301        .forward_request(session.connection_id, host_connection_id, request)
2302        .await?;
2303    response.send(payload)?;
2304    Ok(())
2305}
2306
2307/// forward a project request to the host. These requests are disallowed
2308/// for guests.
2309async fn forward_mutating_project_request<T>(
2310    request: T,
2311    response: Response<T>,
2312    session: Session,
2313) -> Result<()>
2314where
2315    T: EntityMessage + RequestMessage,
2316{
2317    let project_id = ProjectId::from_proto(request.remote_entity_id());
2318
2319    let host_connection_id = session
2320        .db()
2321        .await
2322        .host_for_mutating_project_request(project_id, session.connection_id)
2323        .await?;
2324    let payload = session
2325        .peer
2326        .forward_request(session.connection_id, host_connection_id, request)
2327        .await?;
2328    response.send(payload)?;
2329    Ok(())
2330}
2331
2332/// Notify other participants that a new buffer has been created
2333async fn create_buffer_for_peer(
2334    request: proto::CreateBufferForPeer,
2335    session: Session,
2336) -> Result<()> {
2337    session
2338        .db()
2339        .await
2340        .check_user_is_project_host(
2341            ProjectId::from_proto(request.project_id),
2342            session.connection_id,
2343        )
2344        .await?;
2345    let peer_id = request.peer_id.context("invalid peer id")?;
2346    session
2347        .peer
2348        .forward_send(session.connection_id, peer_id.into(), request)?;
2349    Ok(())
2350}
2351
2352/// Notify other participants that a buffer has been updated. This is
2353/// allowed for guests as long as the update is limited to selections.
2354async fn update_buffer(
2355    request: proto::UpdateBuffer,
2356    response: Response<proto::UpdateBuffer>,
2357    session: Session,
2358) -> Result<()> {
2359    let project_id = ProjectId::from_proto(request.project_id);
2360    let mut capability = Capability::ReadOnly;
2361
2362    for op in request.operations.iter() {
2363        match op.variant {
2364            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2365            Some(_) => capability = Capability::ReadWrite,
2366        }
2367    }
2368
2369    let host = {
2370        let guard = session
2371            .db()
2372            .await
2373            .connections_for_buffer_update(project_id, session.connection_id, capability)
2374            .await?;
2375
2376        let (host, guests) = &*guard;
2377
2378        broadcast(
2379            Some(session.connection_id),
2380            guests.clone(),
2381            |connection_id| {
2382                session
2383                    .peer
2384                    .forward_send(session.connection_id, connection_id, request.clone())
2385            },
2386        );
2387
2388        *host
2389    };
2390
2391    if host != session.connection_id {
2392        session
2393            .peer
2394            .forward_request(session.connection_id, host, request.clone())
2395            .await?;
2396    }
2397
2398    response.send(proto::Ack {})?;
2399    Ok(())
2400}
2401
2402async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2403    let project_id = ProjectId::from_proto(message.project_id);
2404
2405    let operation = message.operation.as_ref().context("invalid operation")?;
2406    let capability = match operation.variant.as_ref() {
2407        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2408            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2409                match buffer_op.variant {
2410                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2411                        Capability::ReadOnly
2412                    }
2413                    _ => Capability::ReadWrite,
2414                }
2415            } else {
2416                Capability::ReadWrite
2417            }
2418        }
2419        Some(_) => Capability::ReadWrite,
2420        None => Capability::ReadOnly,
2421    };
2422
2423    let guard = session
2424        .db()
2425        .await
2426        .connections_for_buffer_update(project_id, session.connection_id, capability)
2427        .await?;
2428
2429    let (host, guests) = &*guard;
2430
2431    broadcast(
2432        Some(session.connection_id),
2433        guests.iter().chain([host]).copied(),
2434        |connection_id| {
2435            session
2436                .peer
2437                .forward_send(session.connection_id, connection_id, message.clone())
2438        },
2439    );
2440
2441    Ok(())
2442}
2443
2444/// Notify other participants that a project has been updated.
2445async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2446    request: T,
2447    session: Session,
2448) -> Result<()> {
2449    let project_id = ProjectId::from_proto(request.remote_entity_id());
2450    let project_connection_ids = session
2451        .db()
2452        .await
2453        .project_connection_ids(project_id, session.connection_id, false)
2454        .await?;
2455
2456    broadcast(
2457        Some(session.connection_id),
2458        project_connection_ids.iter().copied(),
2459        |connection_id| {
2460            session
2461                .peer
2462                .forward_send(session.connection_id, connection_id, request.clone())
2463        },
2464    );
2465    Ok(())
2466}
2467
2468/// Start following another user in a call.
2469async fn follow(
2470    request: proto::Follow,
2471    response: Response<proto::Follow>,
2472    session: Session,
2473) -> Result<()> {
2474    let room_id = RoomId::from_proto(request.room_id);
2475    let project_id = request.project_id.map(ProjectId::from_proto);
2476    let leader_id = request.leader_id.context("invalid leader id")?.into();
2477    let follower_id = session.connection_id;
2478
2479    session
2480        .db()
2481        .await
2482        .check_room_participants(room_id, leader_id, session.connection_id)
2483        .await?;
2484
2485    let response_payload = session
2486        .peer
2487        .forward_request(session.connection_id, leader_id, request)
2488        .await?;
2489    response.send(response_payload)?;
2490
2491    if let Some(project_id) = project_id {
2492        let room = session
2493            .db()
2494            .await
2495            .follow(room_id, project_id, leader_id, follower_id)
2496            .await?;
2497        room_updated(&room, &session.peer);
2498    }
2499
2500    Ok(())
2501}
2502
2503/// Stop following another user in a call.
2504async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2505    let room_id = RoomId::from_proto(request.room_id);
2506    let project_id = request.project_id.map(ProjectId::from_proto);
2507    let leader_id = request.leader_id.context("invalid leader id")?.into();
2508    let follower_id = session.connection_id;
2509
2510    session
2511        .db()
2512        .await
2513        .check_room_participants(room_id, leader_id, session.connection_id)
2514        .await?;
2515
2516    session
2517        .peer
2518        .forward_send(session.connection_id, leader_id, request)?;
2519
2520    if let Some(project_id) = project_id {
2521        let room = session
2522            .db()
2523            .await
2524            .unfollow(room_id, project_id, leader_id, follower_id)
2525            .await?;
2526        room_updated(&room, &session.peer);
2527    }
2528
2529    Ok(())
2530}
2531
2532/// Notify everyone following you of your current location.
2533async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2534    let room_id = RoomId::from_proto(request.room_id);
2535    let database = session.db.lock().await;
2536
2537    let connection_ids = if let Some(project_id) = request.project_id {
2538        let project_id = ProjectId::from_proto(project_id);
2539        database
2540            .project_connection_ids(project_id, session.connection_id, true)
2541            .await?
2542    } else {
2543        database
2544            .room_connection_ids(room_id, session.connection_id)
2545            .await?
2546    };
2547
2548    // For now, don't send view update messages back to that view's current leader.
2549    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2550        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2551        _ => None,
2552    });
2553
2554    for connection_id in connection_ids.iter().cloned() {
2555        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2556            session
2557                .peer
2558                .forward_send(session.connection_id, connection_id, request.clone())?;
2559        }
2560    }
2561    Ok(())
2562}
2563
2564/// Get public data about users.
2565async fn get_users(
2566    request: proto::GetUsers,
2567    response: Response<proto::GetUsers>,
2568    session: Session,
2569) -> Result<()> {
2570    let user_ids = request
2571        .user_ids
2572        .into_iter()
2573        .map(UserId::from_proto)
2574        .collect();
2575    let users = session
2576        .db()
2577        .await
2578        .get_users_by_ids(user_ids)
2579        .await?
2580        .into_iter()
2581        .map(|user| proto::User {
2582            id: user.id.to_proto(),
2583            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2584            github_login: user.github_login,
2585            name: user.name,
2586        })
2587        .collect();
2588    response.send(proto::UsersResponse { users })?;
2589    Ok(())
2590}
2591
2592/// Search for users (to invite) buy Github login
2593async fn fuzzy_search_users(
2594    request: proto::FuzzySearchUsers,
2595    response: Response<proto::FuzzySearchUsers>,
2596    session: Session,
2597) -> Result<()> {
2598    let query = request.query;
2599    let users = match query.len() {
2600        0 => vec![],
2601        1 | 2 => session
2602            .db()
2603            .await
2604            .get_user_by_github_login(&query)
2605            .await?
2606            .into_iter()
2607            .collect(),
2608        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2609    };
2610    let users = users
2611        .into_iter()
2612        .filter(|user| user.id != session.user_id())
2613        .map(|user| proto::User {
2614            id: user.id.to_proto(),
2615            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2616            github_login: user.github_login,
2617            name: user.name,
2618        })
2619        .collect();
2620    response.send(proto::UsersResponse { users })?;
2621    Ok(())
2622}
2623
2624/// Send a contact request to another user.
2625async fn request_contact(
2626    request: proto::RequestContact,
2627    response: Response<proto::RequestContact>,
2628    session: Session,
2629) -> Result<()> {
2630    let requester_id = session.user_id();
2631    let responder_id = UserId::from_proto(request.responder_id);
2632    if requester_id == responder_id {
2633        return Err(anyhow!("cannot add yourself as a contact"))?;
2634    }
2635
2636    let notifications = session
2637        .db()
2638        .await
2639        .send_contact_request(requester_id, responder_id)
2640        .await?;
2641
2642    // Update outgoing contact requests of requester
2643    let mut update = proto::UpdateContacts::default();
2644    update.outgoing_requests.push(responder_id.to_proto());
2645    for connection_id in session
2646        .connection_pool()
2647        .await
2648        .user_connection_ids(requester_id)
2649    {
2650        session.peer.send(connection_id, update.clone())?;
2651    }
2652
2653    // Update incoming contact requests of responder
2654    let mut update = proto::UpdateContacts::default();
2655    update
2656        .incoming_requests
2657        .push(proto::IncomingContactRequest {
2658            requester_id: requester_id.to_proto(),
2659        });
2660    let connection_pool = session.connection_pool().await;
2661    for connection_id in connection_pool.user_connection_ids(responder_id) {
2662        session.peer.send(connection_id, update.clone())?;
2663    }
2664
2665    send_notifications(&connection_pool, &session.peer, notifications);
2666
2667    response.send(proto::Ack {})?;
2668    Ok(())
2669}
2670
2671/// Accept or decline a contact request
2672async fn respond_to_contact_request(
2673    request: proto::RespondToContactRequest,
2674    response: Response<proto::RespondToContactRequest>,
2675    session: Session,
2676) -> Result<()> {
2677    let responder_id = session.user_id();
2678    let requester_id = UserId::from_proto(request.requester_id);
2679    let db = session.db().await;
2680    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2681        db.dismiss_contact_notification(responder_id, requester_id)
2682            .await?;
2683    } else {
2684        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2685
2686        let notifications = db
2687            .respond_to_contact_request(responder_id, requester_id, accept)
2688            .await?;
2689        let requester_busy = db.is_user_busy(requester_id).await?;
2690        let responder_busy = db.is_user_busy(responder_id).await?;
2691
2692        let pool = session.connection_pool().await;
2693        // Update responder with new contact
2694        let mut update = proto::UpdateContacts::default();
2695        if accept {
2696            update
2697                .contacts
2698                .push(contact_for_user(requester_id, requester_busy, &pool));
2699        }
2700        update
2701            .remove_incoming_requests
2702            .push(requester_id.to_proto());
2703        for connection_id in pool.user_connection_ids(responder_id) {
2704            session.peer.send(connection_id, update.clone())?;
2705        }
2706
2707        // Update requester with new contact
2708        let mut update = proto::UpdateContacts::default();
2709        if accept {
2710            update
2711                .contacts
2712                .push(contact_for_user(responder_id, responder_busy, &pool));
2713        }
2714        update
2715            .remove_outgoing_requests
2716            .push(responder_id.to_proto());
2717
2718        for connection_id in pool.user_connection_ids(requester_id) {
2719            session.peer.send(connection_id, update.clone())?;
2720        }
2721
2722        send_notifications(&pool, &session.peer, notifications);
2723    }
2724
2725    response.send(proto::Ack {})?;
2726    Ok(())
2727}
2728
2729/// Remove a contact.
2730async fn remove_contact(
2731    request: proto::RemoveContact,
2732    response: Response<proto::RemoveContact>,
2733    session: Session,
2734) -> Result<()> {
2735    let requester_id = session.user_id();
2736    let responder_id = UserId::from_proto(request.user_id);
2737    let db = session.db().await;
2738    let (contact_accepted, deleted_notification_id) =
2739        db.remove_contact(requester_id, responder_id).await?;
2740
2741    let pool = session.connection_pool().await;
2742    // Update outgoing contact requests of requester
2743    let mut update = proto::UpdateContacts::default();
2744    if contact_accepted {
2745        update.remove_contacts.push(responder_id.to_proto());
2746    } else {
2747        update
2748            .remove_outgoing_requests
2749            .push(responder_id.to_proto());
2750    }
2751    for connection_id in pool.user_connection_ids(requester_id) {
2752        session.peer.send(connection_id, update.clone())?;
2753    }
2754
2755    // Update incoming contact requests of responder
2756    let mut update = proto::UpdateContacts::default();
2757    if contact_accepted {
2758        update.remove_contacts.push(requester_id.to_proto());
2759    } else {
2760        update
2761            .remove_incoming_requests
2762            .push(requester_id.to_proto());
2763    }
2764    for connection_id in pool.user_connection_ids(responder_id) {
2765        session.peer.send(connection_id, update.clone())?;
2766        if let Some(notification_id) = deleted_notification_id {
2767            session.peer.send(
2768                connection_id,
2769                proto::DeleteNotification {
2770                    notification_id: notification_id.to_proto(),
2771                },
2772            )?;
2773        }
2774    }
2775
2776    response.send(proto::Ack {})?;
2777    Ok(())
2778}
2779
2780fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2781    version.0.minor() < 139
2782}
2783
2784async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2785    if is_staff {
2786        return Ok(proto::Plan::ZedPro);
2787    }
2788
2789    let subscription = db.get_active_billing_subscription(user_id).await?;
2790    let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2791
2792    let plan = if let Some(subscription_kind) = subscription_kind {
2793        match subscription_kind {
2794            SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2795            SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2796            SubscriptionKind::ZedFree => proto::Plan::Free,
2797        }
2798    } else {
2799        proto::Plan::Free
2800    };
2801
2802    Ok(plan)
2803}
2804
2805async fn make_update_user_plan_message(
2806    user: &User,
2807    is_staff: bool,
2808    db: &Arc<Database>,
2809    llm_db: Option<Arc<LlmDatabase>>,
2810) -> Result<proto::UpdateUserPlan> {
2811    let feature_flags = db.get_user_flags(user.id).await?;
2812    let plan = current_plan(db, user.id, is_staff).await?;
2813    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
2814    let billing_preferences = db.get_billing_preferences(user.id).await?;
2815
2816    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
2817        let subscription = db.get_active_billing_subscription(user.id).await?;
2818
2819        let subscription_period =
2820            crate::db::billing_subscription::Model::current_period(subscription, is_staff);
2821
2822        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
2823            llm_db
2824                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
2825                .await?
2826        } else {
2827            None
2828        };
2829
2830        (subscription_period, usage)
2831    } else {
2832        (None, None)
2833    };
2834
2835    let bypass_account_age_check = feature_flags
2836        .iter()
2837        .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
2838    let account_too_young = !matches!(plan, proto::Plan::ZedPro)
2839        && !bypass_account_age_check
2840        && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;
2841
2842    Ok(proto::UpdateUserPlan {
2843        plan: plan.into(),
2844        trial_started_at: billing_customer
2845            .as_ref()
2846            .and_then(|billing_customer| billing_customer.trial_started_at)
2847            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2848        is_usage_based_billing_enabled: if is_staff {
2849            Some(true)
2850        } else {
2851            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
2852        },
2853        subscription_period: subscription_period.map(|(started_at, ended_at)| {
2854            proto::SubscriptionPeriod {
2855                started_at: started_at.timestamp() as u64,
2856                ended_at: ended_at.timestamp() as u64,
2857            }
2858        }),
2859        account_too_young: Some(account_too_young),
2860        has_overdue_invoices: billing_customer
2861            .map(|billing_customer| billing_customer.has_overdue_invoices),
2862        usage: Some(
2863            usage
2864                .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
2865                .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
2866        ),
2867    })
2868}
2869
2870fn model_requests_limit(
2871    plan: cloud_llm_client::Plan,
2872    feature_flags: &Vec<String>,
2873) -> cloud_llm_client::UsageLimit {
2874    match plan.model_requests_limit() {
2875        cloud_llm_client::UsageLimit::Limited(limit) => {
2876            let limit = if plan == cloud_llm_client::Plan::ZedProTrial
2877                && feature_flags
2878                    .iter()
2879                    .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2880            {
2881                1_000
2882            } else {
2883                limit
2884            };
2885
2886            cloud_llm_client::UsageLimit::Limited(limit)
2887        }
2888        cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited,
2889    }
2890}
2891
2892fn subscription_usage_to_proto(
2893    plan: proto::Plan,
2894    usage: crate::llm::db::subscription_usage::Model,
2895    feature_flags: &Vec<String>,
2896) -> proto::SubscriptionUsage {
2897    let plan = match plan {
2898        proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2899        proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
2900        proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
2901    };
2902
2903    proto::SubscriptionUsage {
2904        model_requests_usage_amount: usage.model_requests as u32,
2905        model_requests_usage_limit: Some(proto::UsageLimit {
2906            variant: Some(match model_requests_limit(plan, feature_flags) {
2907                cloud_llm_client::UsageLimit::Limited(limit) => {
2908                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2909                        limit: limit as u32,
2910                    })
2911                }
2912                cloud_llm_client::UsageLimit::Unlimited => {
2913                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2914                }
2915            }),
2916        }),
2917        edit_predictions_usage_amount: usage.edit_predictions as u32,
2918        edit_predictions_usage_limit: Some(proto::UsageLimit {
2919            variant: Some(match plan.edit_predictions_limit() {
2920                cloud_llm_client::UsageLimit::Limited(limit) => {
2921                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2922                        limit: limit as u32,
2923                    })
2924                }
2925                cloud_llm_client::UsageLimit::Unlimited => {
2926                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2927                }
2928            }),
2929        }),
2930    }
2931}
2932
2933fn make_default_subscription_usage(
2934    plan: proto::Plan,
2935    feature_flags: &Vec<String>,
2936) -> proto::SubscriptionUsage {
2937    let plan = match plan {
2938        proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2939        proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
2940        proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
2941    };
2942
2943    proto::SubscriptionUsage {
2944        model_requests_usage_amount: 0,
2945        model_requests_usage_limit: Some(proto::UsageLimit {
2946            variant: Some(match model_requests_limit(plan, feature_flags) {
2947                cloud_llm_client::UsageLimit::Limited(limit) => {
2948                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2949                        limit: limit as u32,
2950                    })
2951                }
2952                cloud_llm_client::UsageLimit::Unlimited => {
2953                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2954                }
2955            }),
2956        }),
2957        edit_predictions_usage_amount: 0,
2958        edit_predictions_usage_limit: Some(proto::UsageLimit {
2959            variant: Some(match plan.edit_predictions_limit() {
2960                cloud_llm_client::UsageLimit::Limited(limit) => {
2961                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2962                        limit: limit as u32,
2963                    })
2964                }
2965                cloud_llm_client::UsageLimit::Unlimited => {
2966                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2967                }
2968            }),
2969        }),
2970    }
2971}
2972
2973async fn update_user_plan(session: &Session) -> Result<()> {
2974    let db = session.db().await;
2975
2976    let update_user_plan = make_update_user_plan_message(
2977        session.principal.user(),
2978        session.is_staff(),
2979        &db.0,
2980        session.app_state.llm_db.clone(),
2981    )
2982    .await?;
2983
2984    session
2985        .peer
2986        .send(session.connection_id, update_user_plan)
2987        .trace_err();
2988
2989    Ok(())
2990}
2991
2992async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2993    subscribe_user_to_channels(session.user_id(), &session).await?;
2994    Ok(())
2995}
2996
2997async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2998    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2999    let mut pool = session.connection_pool().await;
3000    for membership in &channels_for_user.channel_memberships {
3001        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3002    }
3003    session.peer.send(
3004        session.connection_id,
3005        build_update_user_channels(&channels_for_user),
3006    )?;
3007    session.peer.send(
3008        session.connection_id,
3009        build_channels_update(channels_for_user),
3010    )?;
3011    Ok(())
3012}
3013
3014/// Creates a new channel.
3015async fn create_channel(
3016    request: proto::CreateChannel,
3017    response: Response<proto::CreateChannel>,
3018    session: Session,
3019) -> Result<()> {
3020    let db = session.db().await;
3021
3022    let parent_id = request.parent_id.map(ChannelId::from_proto);
3023    let (channel, membership) = db
3024        .create_channel(&request.name, parent_id, session.user_id())
3025        .await?;
3026
3027    let root_id = channel.root_id();
3028    let channel = Channel::from_model(channel);
3029
3030    response.send(proto::CreateChannelResponse {
3031        channel: Some(channel.to_proto()),
3032        parent_id: request.parent_id,
3033    })?;
3034
3035    let mut connection_pool = session.connection_pool().await;
3036    if let Some(membership) = membership {
3037        connection_pool.subscribe_to_channel(
3038            membership.user_id,
3039            membership.channel_id,
3040            membership.role,
3041        );
3042        let update = proto::UpdateUserChannels {
3043            channel_memberships: vec![proto::ChannelMembership {
3044                channel_id: membership.channel_id.to_proto(),
3045                role: membership.role.into(),
3046            }],
3047            ..Default::default()
3048        };
3049        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3050            session.peer.send(connection_id, update.clone())?;
3051        }
3052    }
3053
3054    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3055        if !role.can_see_channel(channel.visibility) {
3056            continue;
3057        }
3058
3059        let update = proto::UpdateChannels {
3060            channels: vec![channel.to_proto()],
3061            ..Default::default()
3062        };
3063        session.peer.send(connection_id, update.clone())?;
3064    }
3065
3066    Ok(())
3067}
3068
3069/// Delete a channel
3070async fn delete_channel(
3071    request: proto::DeleteChannel,
3072    response: Response<proto::DeleteChannel>,
3073    session: Session,
3074) -> Result<()> {
3075    let db = session.db().await;
3076
3077    let channel_id = request.channel_id;
3078    let (root_channel, removed_channels) = db
3079        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3080        .await?;
3081    response.send(proto::Ack {})?;
3082
3083    // Notify members of removed channels
3084    let mut update = proto::UpdateChannels::default();
3085    update
3086        .delete_channels
3087        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3088
3089    let connection_pool = session.connection_pool().await;
3090    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3091        session.peer.send(connection_id, update.clone())?;
3092    }
3093
3094    Ok(())
3095}
3096
3097/// Invite someone to join a channel.
3098async fn invite_channel_member(
3099    request: proto::InviteChannelMember,
3100    response: Response<proto::InviteChannelMember>,
3101    session: Session,
3102) -> Result<()> {
3103    let db = session.db().await;
3104    let channel_id = ChannelId::from_proto(request.channel_id);
3105    let invitee_id = UserId::from_proto(request.user_id);
3106    let InviteMemberResult {
3107        channel,
3108        notifications,
3109    } = db
3110        .invite_channel_member(
3111            channel_id,
3112            invitee_id,
3113            session.user_id(),
3114            request.role().into(),
3115        )
3116        .await?;
3117
3118    let update = proto::UpdateChannels {
3119        channel_invitations: vec![channel.to_proto()],
3120        ..Default::default()
3121    };
3122
3123    let connection_pool = session.connection_pool().await;
3124    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3125        session.peer.send(connection_id, update.clone())?;
3126    }
3127
3128    send_notifications(&connection_pool, &session.peer, notifications);
3129
3130    response.send(proto::Ack {})?;
3131    Ok(())
3132}
3133
3134/// remove someone from a channel
3135async fn remove_channel_member(
3136    request: proto::RemoveChannelMember,
3137    response: Response<proto::RemoveChannelMember>,
3138    session: Session,
3139) -> Result<()> {
3140    let db = session.db().await;
3141    let channel_id = ChannelId::from_proto(request.channel_id);
3142    let member_id = UserId::from_proto(request.user_id);
3143
3144    let RemoveChannelMemberResult {
3145        membership_update,
3146        notification_id,
3147    } = db
3148        .remove_channel_member(channel_id, member_id, session.user_id())
3149        .await?;
3150
3151    let mut connection_pool = session.connection_pool().await;
3152    notify_membership_updated(
3153        &mut connection_pool,
3154        membership_update,
3155        member_id,
3156        &session.peer,
3157    );
3158    for connection_id in connection_pool.user_connection_ids(member_id) {
3159        if let Some(notification_id) = notification_id {
3160            session
3161                .peer
3162                .send(
3163                    connection_id,
3164                    proto::DeleteNotification {
3165                        notification_id: notification_id.to_proto(),
3166                    },
3167                )
3168                .trace_err();
3169        }
3170    }
3171
3172    response.send(proto::Ack {})?;
3173    Ok(())
3174}
3175
3176/// Toggle the channel between public and private.
3177/// Care is taken to maintain the invariant that public channels only descend from public channels,
3178/// (though members-only channels can appear at any point in the hierarchy).
3179async fn set_channel_visibility(
3180    request: proto::SetChannelVisibility,
3181    response: Response<proto::SetChannelVisibility>,
3182    session: Session,
3183) -> Result<()> {
3184    let db = session.db().await;
3185    let channel_id = ChannelId::from_proto(request.channel_id);
3186    let visibility = request.visibility().into();
3187
3188    let channel_model = db
3189        .set_channel_visibility(channel_id, visibility, session.user_id())
3190        .await?;
3191    let root_id = channel_model.root_id();
3192    let channel = Channel::from_model(channel_model);
3193
3194    let mut connection_pool = session.connection_pool().await;
3195    for (user_id, role) in connection_pool
3196        .channel_user_ids(root_id)
3197        .collect::<Vec<_>>()
3198        .into_iter()
3199    {
3200        let update = if role.can_see_channel(channel.visibility) {
3201            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3202            proto::UpdateChannels {
3203                channels: vec![channel.to_proto()],
3204                ..Default::default()
3205            }
3206        } else {
3207            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3208            proto::UpdateChannels {
3209                delete_channels: vec![channel.id.to_proto()],
3210                ..Default::default()
3211            }
3212        };
3213
3214        for connection_id in connection_pool.user_connection_ids(user_id) {
3215            session.peer.send(connection_id, update.clone())?;
3216        }
3217    }
3218
3219    response.send(proto::Ack {})?;
3220    Ok(())
3221}
3222
3223/// Alter the role for a user in the channel.
3224async fn set_channel_member_role(
3225    request: proto::SetChannelMemberRole,
3226    response: Response<proto::SetChannelMemberRole>,
3227    session: Session,
3228) -> Result<()> {
3229    let db = session.db().await;
3230    let channel_id = ChannelId::from_proto(request.channel_id);
3231    let member_id = UserId::from_proto(request.user_id);
3232    let result = db
3233        .set_channel_member_role(
3234            channel_id,
3235            session.user_id(),
3236            member_id,
3237            request.role().into(),
3238        )
3239        .await?;
3240
3241    match result {
3242        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3243            let mut connection_pool = session.connection_pool().await;
3244            notify_membership_updated(
3245                &mut connection_pool,
3246                membership_update,
3247                member_id,
3248                &session.peer,
3249            )
3250        }
3251        db::SetMemberRoleResult::InviteUpdated(channel) => {
3252            let update = proto::UpdateChannels {
3253                channel_invitations: vec![channel.to_proto()],
3254                ..Default::default()
3255            };
3256
3257            for connection_id in session
3258                .connection_pool()
3259                .await
3260                .user_connection_ids(member_id)
3261            {
3262                session.peer.send(connection_id, update.clone())?;
3263            }
3264        }
3265    }
3266
3267    response.send(proto::Ack {})?;
3268    Ok(())
3269}
3270
3271/// Change the name of a channel
3272async fn rename_channel(
3273    request: proto::RenameChannel,
3274    response: Response<proto::RenameChannel>,
3275    session: Session,
3276) -> Result<()> {
3277    let db = session.db().await;
3278    let channel_id = ChannelId::from_proto(request.channel_id);
3279    let channel_model = db
3280        .rename_channel(channel_id, session.user_id(), &request.name)
3281        .await?;
3282    let root_id = channel_model.root_id();
3283    let channel = Channel::from_model(channel_model);
3284
3285    response.send(proto::RenameChannelResponse {
3286        channel: Some(channel.to_proto()),
3287    })?;
3288
3289    let connection_pool = session.connection_pool().await;
3290    let update = proto::UpdateChannels {
3291        channels: vec![channel.to_proto()],
3292        ..Default::default()
3293    };
3294    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3295        if role.can_see_channel(channel.visibility) {
3296            session.peer.send(connection_id, update.clone())?;
3297        }
3298    }
3299
3300    Ok(())
3301}
3302
3303/// Move a channel to a new parent.
3304async fn move_channel(
3305    request: proto::MoveChannel,
3306    response: Response<proto::MoveChannel>,
3307    session: Session,
3308) -> Result<()> {
3309    let channel_id = ChannelId::from_proto(request.channel_id);
3310    let to = ChannelId::from_proto(request.to);
3311
3312    let (root_id, channels) = session
3313        .db()
3314        .await
3315        .move_channel(channel_id, to, session.user_id())
3316        .await?;
3317
3318    let connection_pool = session.connection_pool().await;
3319    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3320        let channels = channels
3321            .iter()
3322            .filter_map(|channel| {
3323                if role.can_see_channel(channel.visibility) {
3324                    Some(channel.to_proto())
3325                } else {
3326                    None
3327                }
3328            })
3329            .collect::<Vec<_>>();
3330        if channels.is_empty() {
3331            continue;
3332        }
3333
3334        let update = proto::UpdateChannels {
3335            channels,
3336            ..Default::default()
3337        };
3338
3339        session.peer.send(connection_id, update.clone())?;
3340    }
3341
3342    response.send(Ack {})?;
3343    Ok(())
3344}
3345
3346async fn reorder_channel(
3347    request: proto::ReorderChannel,
3348    response: Response<proto::ReorderChannel>,
3349    session: Session,
3350) -> Result<()> {
3351    let channel_id = ChannelId::from_proto(request.channel_id);
3352    let direction = request.direction();
3353
3354    let updated_channels = session
3355        .db()
3356        .await
3357        .reorder_channel(channel_id, direction, session.user_id())
3358        .await?;
3359
3360    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3361        let connection_pool = session.connection_pool().await;
3362        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3363            let channels = updated_channels
3364                .iter()
3365                .filter_map(|channel| {
3366                    if role.can_see_channel(channel.visibility) {
3367                        Some(channel.to_proto())
3368                    } else {
3369                        None
3370                    }
3371                })
3372                .collect::<Vec<_>>();
3373
3374            if channels.is_empty() {
3375                continue;
3376            }
3377
3378            let update = proto::UpdateChannels {
3379                channels,
3380                ..Default::default()
3381            };
3382
3383            session.peer.send(connection_id, update.clone())?;
3384        }
3385    }
3386
3387    response.send(Ack {})?;
3388    Ok(())
3389}
3390
3391/// Get the list of channel members
3392async fn get_channel_members(
3393    request: proto::GetChannelMembers,
3394    response: Response<proto::GetChannelMembers>,
3395    session: Session,
3396) -> Result<()> {
3397    let db = session.db().await;
3398    let channel_id = ChannelId::from_proto(request.channel_id);
3399    let limit = if request.limit == 0 {
3400        u16::MAX as u64
3401    } else {
3402        request.limit
3403    };
3404    let (members, users) = db
3405        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3406        .await?;
3407    response.send(proto::GetChannelMembersResponse { members, users })?;
3408    Ok(())
3409}
3410
3411/// Accept or decline a channel invitation.
3412async fn respond_to_channel_invite(
3413    request: proto::RespondToChannelInvite,
3414    response: Response<proto::RespondToChannelInvite>,
3415    session: Session,
3416) -> Result<()> {
3417    let db = session.db().await;
3418    let channel_id = ChannelId::from_proto(request.channel_id);
3419    let RespondToChannelInvite {
3420        membership_update,
3421        notifications,
3422    } = db
3423        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3424        .await?;
3425
3426    let mut connection_pool = session.connection_pool().await;
3427    if let Some(membership_update) = membership_update {
3428        notify_membership_updated(
3429            &mut connection_pool,
3430            membership_update,
3431            session.user_id(),
3432            &session.peer,
3433        );
3434    } else {
3435        let update = proto::UpdateChannels {
3436            remove_channel_invitations: vec![channel_id.to_proto()],
3437            ..Default::default()
3438        };
3439
3440        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3441            session.peer.send(connection_id, update.clone())?;
3442        }
3443    };
3444
3445    send_notifications(&connection_pool, &session.peer, notifications);
3446
3447    response.send(proto::Ack {})?;
3448
3449    Ok(())
3450}
3451
3452/// Join the channels' room
3453async fn join_channel(
3454    request: proto::JoinChannel,
3455    response: Response<proto::JoinChannel>,
3456    session: Session,
3457) -> Result<()> {
3458    let channel_id = ChannelId::from_proto(request.channel_id);
3459    join_channel_internal(channel_id, Box::new(response), session).await
3460}
3461
3462trait JoinChannelInternalResponse {
3463    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3464}
3465impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3466    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3467        Response::<proto::JoinChannel>::send(self, result)
3468    }
3469}
3470impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3471    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3472        Response::<proto::JoinRoom>::send(self, result)
3473    }
3474}
3475
3476async fn join_channel_internal(
3477    channel_id: ChannelId,
3478    response: Box<impl JoinChannelInternalResponse>,
3479    session: Session,
3480) -> Result<()> {
3481    let joined_room = {
3482        let mut db = session.db().await;
3483        // If zed quits without leaving the room, and the user re-opens zed before the
3484        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3485        // room they were in.
3486        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3487            tracing::info!(
3488                stale_connection_id = %connection,
3489                "cleaning up stale connection",
3490            );
3491            drop(db);
3492            leave_room_for_session(&session, connection).await?;
3493            db = session.db().await;
3494        }
3495
3496        let (joined_room, membership_updated, role) = db
3497            .join_channel(channel_id, session.user_id(), session.connection_id)
3498            .await?;
3499
3500        let live_kit_connection_info =
3501            session
3502                .app_state
3503                .livekit_client
3504                .as_ref()
3505                .and_then(|live_kit| {
3506                    let (can_publish, token) = if role == ChannelRole::Guest {
3507                        (
3508                            false,
3509                            live_kit
3510                                .guest_token(
3511                                    &joined_room.room.livekit_room,
3512                                    &session.user_id().to_string(),
3513                                )
3514                                .trace_err()?,
3515                        )
3516                    } else {
3517                        (
3518                            true,
3519                            live_kit
3520                                .room_token(
3521                                    &joined_room.room.livekit_room,
3522                                    &session.user_id().to_string(),
3523                                )
3524                                .trace_err()?,
3525                        )
3526                    };
3527
3528                    Some(LiveKitConnectionInfo {
3529                        server_url: live_kit.url().into(),
3530                        token,
3531                        can_publish,
3532                    })
3533                });
3534
3535        response.send(proto::JoinRoomResponse {
3536            room: Some(joined_room.room.clone()),
3537            channel_id: joined_room
3538                .channel
3539                .as_ref()
3540                .map(|channel| channel.id.to_proto()),
3541            live_kit_connection_info,
3542        })?;
3543
3544        let mut connection_pool = session.connection_pool().await;
3545        if let Some(membership_updated) = membership_updated {
3546            notify_membership_updated(
3547                &mut connection_pool,
3548                membership_updated,
3549                session.user_id(),
3550                &session.peer,
3551            );
3552        }
3553
3554        room_updated(&joined_room.room, &session.peer);
3555
3556        joined_room
3557    };
3558
3559    channel_updated(
3560        &joined_room.channel.context("channel not returned")?,
3561        &joined_room.room,
3562        &session.peer,
3563        &*session.connection_pool().await,
3564    );
3565
3566    update_user_contacts(session.user_id(), &session).await?;
3567    Ok(())
3568}
3569
3570/// Start editing the channel notes
3571async fn join_channel_buffer(
3572    request: proto::JoinChannelBuffer,
3573    response: Response<proto::JoinChannelBuffer>,
3574    session: Session,
3575) -> Result<()> {
3576    let db = session.db().await;
3577    let channel_id = ChannelId::from_proto(request.channel_id);
3578
3579    let open_response = db
3580        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3581        .await?;
3582
3583    let collaborators = open_response.collaborators.clone();
3584    response.send(open_response)?;
3585
3586    let update = UpdateChannelBufferCollaborators {
3587        channel_id: channel_id.to_proto(),
3588        collaborators: collaborators.clone(),
3589    };
3590    channel_buffer_updated(
3591        session.connection_id,
3592        collaborators
3593            .iter()
3594            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3595        &update,
3596        &session.peer,
3597    );
3598
3599    Ok(())
3600}
3601
3602/// Edit the channel notes
3603async fn update_channel_buffer(
3604    request: proto::UpdateChannelBuffer,
3605    session: Session,
3606) -> Result<()> {
3607    let db = session.db().await;
3608    let channel_id = ChannelId::from_proto(request.channel_id);
3609
3610    let (collaborators, epoch, version) = db
3611        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3612        .await?;
3613
3614    channel_buffer_updated(
3615        session.connection_id,
3616        collaborators.clone(),
3617        &proto::UpdateChannelBuffer {
3618            channel_id: channel_id.to_proto(),
3619            operations: request.operations,
3620        },
3621        &session.peer,
3622    );
3623
3624    let pool = &*session.connection_pool().await;
3625
3626    let non_collaborators =
3627        pool.channel_connection_ids(channel_id)
3628            .filter_map(|(connection_id, _)| {
3629                if collaborators.contains(&connection_id) {
3630                    None
3631                } else {
3632                    Some(connection_id)
3633                }
3634            });
3635
3636    broadcast(None, non_collaborators, |peer_id| {
3637        session.peer.send(
3638            peer_id,
3639            proto::UpdateChannels {
3640                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3641                    channel_id: channel_id.to_proto(),
3642                    epoch: epoch as u64,
3643                    version: version.clone(),
3644                }],
3645                ..Default::default()
3646            },
3647        )
3648    });
3649
3650    Ok(())
3651}
3652
3653/// Rejoin the channel notes after a connection blip
3654async fn rejoin_channel_buffers(
3655    request: proto::RejoinChannelBuffers,
3656    response: Response<proto::RejoinChannelBuffers>,
3657    session: Session,
3658) -> Result<()> {
3659    let db = session.db().await;
3660    let buffers = db
3661        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3662        .await?;
3663
3664    for rejoined_buffer in &buffers {
3665        let collaborators_to_notify = rejoined_buffer
3666            .buffer
3667            .collaborators
3668            .iter()
3669            .filter_map(|c| Some(c.peer_id?.into()));
3670        channel_buffer_updated(
3671            session.connection_id,
3672            collaborators_to_notify,
3673            &proto::UpdateChannelBufferCollaborators {
3674                channel_id: rejoined_buffer.buffer.channel_id,
3675                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3676            },
3677            &session.peer,
3678        );
3679    }
3680
3681    response.send(proto::RejoinChannelBuffersResponse {
3682        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3683    })?;
3684
3685    Ok(())
3686}
3687
3688/// Stop editing the channel notes
3689async fn leave_channel_buffer(
3690    request: proto::LeaveChannelBuffer,
3691    response: Response<proto::LeaveChannelBuffer>,
3692    session: Session,
3693) -> Result<()> {
3694    let db = session.db().await;
3695    let channel_id = ChannelId::from_proto(request.channel_id);
3696
3697    let left_buffer = db
3698        .leave_channel_buffer(channel_id, session.connection_id)
3699        .await?;
3700
3701    response.send(Ack {})?;
3702
3703    channel_buffer_updated(
3704        session.connection_id,
3705        left_buffer.connections,
3706        &proto::UpdateChannelBufferCollaborators {
3707            channel_id: channel_id.to_proto(),
3708            collaborators: left_buffer.collaborators,
3709        },
3710        &session.peer,
3711    );
3712
3713    Ok(())
3714}
3715
3716fn channel_buffer_updated<T: EnvelopedMessage>(
3717    sender_id: ConnectionId,
3718    collaborators: impl IntoIterator<Item = ConnectionId>,
3719    message: &T,
3720    peer: &Peer,
3721) {
3722    broadcast(Some(sender_id), collaborators, |peer_id| {
3723        peer.send(peer_id, message.clone())
3724    });
3725}
3726
3727fn send_notifications(
3728    connection_pool: &ConnectionPool,
3729    peer: &Peer,
3730    notifications: db::NotificationBatch,
3731) {
3732    for (user_id, notification) in notifications {
3733        for connection_id in connection_pool.user_connection_ids(user_id) {
3734            if let Err(error) = peer.send(
3735                connection_id,
3736                proto::AddNotification {
3737                    notification: Some(notification.clone()),
3738                },
3739            ) {
3740                tracing::error!(
3741                    "failed to send notification to {:?} {}",
3742                    connection_id,
3743                    error
3744                );
3745            }
3746        }
3747    }
3748}
3749
3750/// Send a message to the channel
3751async fn send_channel_message(
3752    request: proto::SendChannelMessage,
3753    response: Response<proto::SendChannelMessage>,
3754    session: Session,
3755) -> Result<()> {
3756    // Validate the message body.
3757    let body = request.body.trim().to_string();
3758    if body.len() > MAX_MESSAGE_LEN {
3759        return Err(anyhow!("message is too long"))?;
3760    }
3761    if body.is_empty() {
3762        return Err(anyhow!("message can't be blank"))?;
3763    }
3764
3765    // TODO: adjust mentions if body is trimmed
3766
3767    let timestamp = OffsetDateTime::now_utc();
3768    let nonce = request.nonce.context("nonce can't be blank")?;
3769
3770    let channel_id = ChannelId::from_proto(request.channel_id);
3771    let CreatedChannelMessage {
3772        message_id,
3773        participant_connection_ids,
3774        notifications,
3775    } = session
3776        .db()
3777        .await
3778        .create_channel_message(
3779            channel_id,
3780            session.user_id(),
3781            &body,
3782            &request.mentions,
3783            timestamp,
3784            nonce.clone().into(),
3785            request.reply_to_message_id.map(MessageId::from_proto),
3786        )
3787        .await?;
3788
3789    let message = proto::ChannelMessage {
3790        sender_id: session.user_id().to_proto(),
3791        id: message_id.to_proto(),
3792        body,
3793        mentions: request.mentions,
3794        timestamp: timestamp.unix_timestamp() as u64,
3795        nonce: Some(nonce),
3796        reply_to_message_id: request.reply_to_message_id,
3797        edited_at: None,
3798    };
3799    broadcast(
3800        Some(session.connection_id),
3801        participant_connection_ids.clone(),
3802        |connection| {
3803            session.peer.send(
3804                connection,
3805                proto::ChannelMessageSent {
3806                    channel_id: channel_id.to_proto(),
3807                    message: Some(message.clone()),
3808                },
3809            )
3810        },
3811    );
3812    response.send(proto::SendChannelMessageResponse {
3813        message: Some(message),
3814    })?;
3815
3816    let pool = &*session.connection_pool().await;
3817    let non_participants =
3818        pool.channel_connection_ids(channel_id)
3819            .filter_map(|(connection_id, _)| {
3820                if participant_connection_ids.contains(&connection_id) {
3821                    None
3822                } else {
3823                    Some(connection_id)
3824                }
3825            });
3826    broadcast(None, non_participants, |peer_id| {
3827        session.peer.send(
3828            peer_id,
3829            proto::UpdateChannels {
3830                latest_channel_message_ids: vec![proto::ChannelMessageId {
3831                    channel_id: channel_id.to_proto(),
3832                    message_id: message_id.to_proto(),
3833                }],
3834                ..Default::default()
3835            },
3836        )
3837    });
3838    send_notifications(pool, &session.peer, notifications);
3839
3840    Ok(())
3841}
3842
3843/// Delete a channel message
3844async fn remove_channel_message(
3845    request: proto::RemoveChannelMessage,
3846    response: Response<proto::RemoveChannelMessage>,
3847    session: Session,
3848) -> Result<()> {
3849    let channel_id = ChannelId::from_proto(request.channel_id);
3850    let message_id = MessageId::from_proto(request.message_id);
3851    let (connection_ids, existing_notification_ids) = session
3852        .db()
3853        .await
3854        .remove_channel_message(channel_id, message_id, session.user_id())
3855        .await?;
3856
3857    broadcast(
3858        Some(session.connection_id),
3859        connection_ids,
3860        move |connection| {
3861            session.peer.send(connection, request.clone())?;
3862
3863            for notification_id in &existing_notification_ids {
3864                session.peer.send(
3865                    connection,
3866                    proto::DeleteNotification {
3867                        notification_id: (*notification_id).to_proto(),
3868                    },
3869                )?;
3870            }
3871
3872            Ok(())
3873        },
3874    );
3875    response.send(proto::Ack {})?;
3876    Ok(())
3877}
3878
3879async fn update_channel_message(
3880    request: proto::UpdateChannelMessage,
3881    response: Response<proto::UpdateChannelMessage>,
3882    session: Session,
3883) -> Result<()> {
3884    let channel_id = ChannelId::from_proto(request.channel_id);
3885    let message_id = MessageId::from_proto(request.message_id);
3886    let updated_at = OffsetDateTime::now_utc();
3887    let UpdatedChannelMessage {
3888        message_id,
3889        participant_connection_ids,
3890        notifications,
3891        reply_to_message_id,
3892        timestamp,
3893        deleted_mention_notification_ids,
3894        updated_mention_notifications,
3895    } = session
3896        .db()
3897        .await
3898        .update_channel_message(
3899            channel_id,
3900            message_id,
3901            session.user_id(),
3902            request.body.as_str(),
3903            &request.mentions,
3904            updated_at,
3905        )
3906        .await?;
3907
3908    let nonce = request.nonce.clone().context("nonce can't be blank")?;
3909
3910    let message = proto::ChannelMessage {
3911        sender_id: session.user_id().to_proto(),
3912        id: message_id.to_proto(),
3913        body: request.body.clone(),
3914        mentions: request.mentions.clone(),
3915        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3916        nonce: Some(nonce),
3917        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3918        edited_at: Some(updated_at.unix_timestamp() as u64),
3919    };
3920
3921    response.send(proto::Ack {})?;
3922
3923    let pool = &*session.connection_pool().await;
3924    broadcast(
3925        Some(session.connection_id),
3926        participant_connection_ids,
3927        |connection| {
3928            session.peer.send(
3929                connection,
3930                proto::ChannelMessageUpdate {
3931                    channel_id: channel_id.to_proto(),
3932                    message: Some(message.clone()),
3933                },
3934            )?;
3935
3936            for notification_id in &deleted_mention_notification_ids {
3937                session.peer.send(
3938                    connection,
3939                    proto::DeleteNotification {
3940                        notification_id: (*notification_id).to_proto(),
3941                    },
3942                )?;
3943            }
3944
3945            for notification in &updated_mention_notifications {
3946                session.peer.send(
3947                    connection,
3948                    proto::UpdateNotification {
3949                        notification: Some(notification.clone()),
3950                    },
3951                )?;
3952            }
3953
3954            Ok(())
3955        },
3956    );
3957
3958    send_notifications(pool, &session.peer, notifications);
3959
3960    Ok(())
3961}
3962
3963/// Mark a channel message as read
3964async fn acknowledge_channel_message(
3965    request: proto::AckChannelMessage,
3966    session: Session,
3967) -> Result<()> {
3968    let channel_id = ChannelId::from_proto(request.channel_id);
3969    let message_id = MessageId::from_proto(request.message_id);
3970    let notifications = session
3971        .db()
3972        .await
3973        .observe_channel_message(channel_id, session.user_id(), message_id)
3974        .await?;
3975    send_notifications(
3976        &*session.connection_pool().await,
3977        &session.peer,
3978        notifications,
3979    );
3980    Ok(())
3981}
3982
3983/// Mark a buffer version as synced
3984async fn acknowledge_buffer_version(
3985    request: proto::AckBufferOperation,
3986    session: Session,
3987) -> Result<()> {
3988    let buffer_id = BufferId::from_proto(request.buffer_id);
3989    session
3990        .db()
3991        .await
3992        .observe_buffer_version(
3993            buffer_id,
3994            session.user_id(),
3995            request.epoch as i32,
3996            &request.version,
3997        )
3998        .await?;
3999    Ok(())
4000}
4001
4002/// Get a Supermaven API key for the user
4003async fn get_supermaven_api_key(
4004    _request: proto::GetSupermavenApiKey,
4005    response: Response<proto::GetSupermavenApiKey>,
4006    session: Session,
4007) -> Result<()> {
4008    let user_id: String = session.user_id().to_string();
4009    if !session.is_staff() {
4010        return Err(anyhow!("supermaven not enabled for this account"))?;
4011    }
4012
4013    let email = session.email().context("user must have an email")?;
4014
4015    let supermaven_admin_api = session
4016        .supermaven_client
4017        .as_ref()
4018        .context("supermaven not configured")?;
4019
4020    let result = supermaven_admin_api
4021        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4022        .await?;
4023
4024    response.send(proto::GetSupermavenApiKeyResponse {
4025        api_key: result.api_key,
4026    })?;
4027
4028    Ok(())
4029}
4030
4031/// Start receiving chat updates for a channel
4032async fn join_channel_chat(
4033    request: proto::JoinChannelChat,
4034    response: Response<proto::JoinChannelChat>,
4035    session: Session,
4036) -> Result<()> {
4037    let channel_id = ChannelId::from_proto(request.channel_id);
4038
4039    let db = session.db().await;
4040    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4041        .await?;
4042    let messages = db
4043        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4044        .await?;
4045    response.send(proto::JoinChannelChatResponse {
4046        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4047        messages,
4048    })?;
4049    Ok(())
4050}
4051
4052/// Stop receiving chat updates for a channel
4053async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
4054    let channel_id = ChannelId::from_proto(request.channel_id);
4055    session
4056        .db()
4057        .await
4058        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4059        .await?;
4060    Ok(())
4061}
4062
4063/// Retrieve the chat history for a channel
4064async fn get_channel_messages(
4065    request: proto::GetChannelMessages,
4066    response: Response<proto::GetChannelMessages>,
4067    session: Session,
4068) -> Result<()> {
4069    let channel_id = ChannelId::from_proto(request.channel_id);
4070    let messages = session
4071        .db()
4072        .await
4073        .get_channel_messages(
4074            channel_id,
4075            session.user_id(),
4076            MESSAGE_COUNT_PER_PAGE,
4077            Some(MessageId::from_proto(request.before_message_id)),
4078        )
4079        .await?;
4080    response.send(proto::GetChannelMessagesResponse {
4081        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4082        messages,
4083    })?;
4084    Ok(())
4085}
4086
4087/// Retrieve specific chat messages
4088async fn get_channel_messages_by_id(
4089    request: proto::GetChannelMessagesById,
4090    response: Response<proto::GetChannelMessagesById>,
4091    session: Session,
4092) -> Result<()> {
4093    let message_ids = request
4094        .message_ids
4095        .iter()
4096        .map(|id| MessageId::from_proto(*id))
4097        .collect::<Vec<_>>();
4098    let messages = session
4099        .db()
4100        .await
4101        .get_channel_messages_by_id(session.user_id(), &message_ids)
4102        .await?;
4103    response.send(proto::GetChannelMessagesResponse {
4104        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4105        messages,
4106    })?;
4107    Ok(())
4108}
4109
4110/// Retrieve the current users notifications
4111async fn get_notifications(
4112    request: proto::GetNotifications,
4113    response: Response<proto::GetNotifications>,
4114    session: Session,
4115) -> Result<()> {
4116    let notifications = session
4117        .db()
4118        .await
4119        .get_notifications(
4120            session.user_id(),
4121            NOTIFICATION_COUNT_PER_PAGE,
4122            request.before_id.map(db::NotificationId::from_proto),
4123        )
4124        .await?;
4125    response.send(proto::GetNotificationsResponse {
4126        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4127        notifications,
4128    })?;
4129    Ok(())
4130}
4131
4132/// Mark notifications as read
4133async fn mark_notification_as_read(
4134    request: proto::MarkNotificationRead,
4135    response: Response<proto::MarkNotificationRead>,
4136    session: Session,
4137) -> Result<()> {
4138    let database = &session.db().await;
4139    let notifications = database
4140        .mark_notification_as_read_by_id(
4141            session.user_id(),
4142            NotificationId::from_proto(request.notification_id),
4143        )
4144        .await?;
4145    send_notifications(
4146        &*session.connection_pool().await,
4147        &session.peer,
4148        notifications,
4149    );
4150    response.send(proto::Ack {})?;
4151    Ok(())
4152}
4153
4154/// Get the current users information
4155async fn get_private_user_info(
4156    _request: proto::GetPrivateUserInfo,
4157    response: Response<proto::GetPrivateUserInfo>,
4158    session: Session,
4159) -> Result<()> {
4160    let db = session.db().await;
4161
4162    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4163    let user = db
4164        .get_user_by_id(session.user_id())
4165        .await?
4166        .context("user not found")?;
4167    let flags = db.get_user_flags(session.user_id()).await?;
4168
4169    response.send(proto::GetPrivateUserInfoResponse {
4170        metrics_id,
4171        staff: user.admin,
4172        flags,
4173        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4174    })?;
4175    Ok(())
4176}
4177
4178/// Accept the terms of service (tos) on behalf of the current user
4179async fn accept_terms_of_service(
4180    _request: proto::AcceptTermsOfService,
4181    response: Response<proto::AcceptTermsOfService>,
4182    session: Session,
4183) -> Result<()> {
4184    let db = session.db().await;
4185
4186    let accepted_tos_at = Utc::now();
4187    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4188        .await?;
4189
4190    response.send(proto::AcceptTermsOfServiceResponse {
4191        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4192    })?;
4193
4194    // When the user accepts the terms of service, we want to refresh their LLM
4195    // token to grant access.
4196    session
4197        .peer
4198        .send(session.connection_id, proto::RefreshLlmToken {})?;
4199
4200    Ok(())
4201}
4202
4203async fn get_llm_api_token(
4204    _request: proto::GetLlmToken,
4205    response: Response<proto::GetLlmToken>,
4206    session: Session,
4207) -> Result<()> {
4208    let db = session.db().await;
4209
4210    let flags = db.get_user_flags(session.user_id()).await?;
4211
4212    let user_id = session.user_id();
4213    let user = db
4214        .get_user_by_id(user_id)
4215        .await?
4216        .with_context(|| format!("user {user_id} not found"))?;
4217
4218    if user.accepted_tos_at.is_none() {
4219        Err(anyhow!("terms of service not accepted"))?
4220    }
4221
4222    let stripe_client = session
4223        .app_state
4224        .stripe_client
4225        .as_ref()
4226        .context("failed to retrieve Stripe client")?;
4227
4228    let stripe_billing = session
4229        .app_state
4230        .stripe_billing
4231        .as_ref()
4232        .context("failed to retrieve Stripe billing object")?;
4233
4234    let billing_customer = if let Some(billing_customer) =
4235        db.get_billing_customer_by_user_id(user.id).await?
4236    {
4237        billing_customer
4238    } else {
4239        let customer_id = stripe_billing
4240            .find_or_create_customer_by_email(user.email_address.as_deref())
4241            .await?;
4242
4243        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
4244            .await?
4245            .context("billing customer not found")?
4246    };
4247
4248    let billing_subscription =
4249        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
4250            billing_subscription
4251        } else {
4252            let stripe_customer_id =
4253                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
4254
4255            let stripe_subscription = stripe_billing
4256                .subscribe_to_zed_free(stripe_customer_id)
4257                .await?;
4258
4259            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
4260                billing_customer_id: billing_customer.id,
4261                kind: Some(SubscriptionKind::ZedFree),
4262                stripe_subscription_id: stripe_subscription.id.to_string(),
4263                stripe_subscription_status: stripe_subscription.status.into(),
4264                stripe_cancellation_reason: None,
4265                stripe_current_period_start: Some(stripe_subscription.current_period_start),
4266                stripe_current_period_end: Some(stripe_subscription.current_period_end),
4267            })
4268            .await?
4269        };
4270
4271    let billing_preferences = db.get_billing_preferences(user.id).await?;
4272
4273    let token = LlmTokenClaims::create(
4274        &user,
4275        session.is_staff(),
4276        billing_customer,
4277        billing_preferences,
4278        &flags,
4279        billing_subscription,
4280        session.system_id.clone(),
4281        &session.app_state.config,
4282    )?;
4283    response.send(proto::GetLlmTokenResponse { token })?;
4284    Ok(())
4285}
4286
4287fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4288    let message = match message {
4289        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4290        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4291        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4292        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4293        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4294            code: frame.code.into(),
4295            reason: frame.reason.as_str().to_owned().into(),
4296        })),
4297        // We should never receive a frame while reading the message, according
4298        // to the `tungstenite` maintainers:
4299        //
4300        // > It cannot occur when you read messages from the WebSocket, but it
4301        // > can be used when you want to send the raw frames (e.g. you want to
4302        // > send the frames to the WebSocket without composing the full message first).
4303        // >
4304        // > — https://github.com/snapview/tungstenite-rs/issues/268
4305        TungsteniteMessage::Frame(_) => {
4306            bail!("received an unexpected frame while reading the message")
4307        }
4308    };
4309
4310    Ok(message)
4311}
4312
4313fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4314    match message {
4315        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4316        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4317        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4318        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4319        AxumMessage::Close(frame) => {
4320            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4321                code: frame.code.into(),
4322                reason: frame.reason.as_ref().into(),
4323            }))
4324        }
4325    }
4326}
4327
4328fn notify_membership_updated(
4329    connection_pool: &mut ConnectionPool,
4330    result: MembershipUpdated,
4331    user_id: UserId,
4332    peer: &Peer,
4333) {
4334    for membership in &result.new_channels.channel_memberships {
4335        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4336    }
4337    for channel_id in &result.removed_channels {
4338        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4339    }
4340
4341    let user_channels_update = proto::UpdateUserChannels {
4342        channel_memberships: result
4343            .new_channels
4344            .channel_memberships
4345            .iter()
4346            .map(|cm| proto::ChannelMembership {
4347                channel_id: cm.channel_id.to_proto(),
4348                role: cm.role.into(),
4349            })
4350            .collect(),
4351        ..Default::default()
4352    };
4353
4354    let mut update = build_channels_update(result.new_channels);
4355    update.delete_channels = result
4356        .removed_channels
4357        .into_iter()
4358        .map(|id| id.to_proto())
4359        .collect();
4360    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4361
4362    for connection_id in connection_pool.user_connection_ids(user_id) {
4363        peer.send(connection_id, user_channels_update.clone())
4364            .trace_err();
4365        peer.send(connection_id, update.clone()).trace_err();
4366    }
4367}
4368
4369fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4370    proto::UpdateUserChannels {
4371        channel_memberships: channels
4372            .channel_memberships
4373            .iter()
4374            .map(|m| proto::ChannelMembership {
4375                channel_id: m.channel_id.to_proto(),
4376                role: m.role.into(),
4377            })
4378            .collect(),
4379        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4380        observed_channel_message_id: channels.observed_channel_messages.clone(),
4381    }
4382}
4383
4384fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4385    let mut update = proto::UpdateChannels::default();
4386
4387    for channel in channels.channels {
4388        update.channels.push(channel.to_proto());
4389    }
4390
4391    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4392    update.latest_channel_message_ids = channels.latest_channel_messages;
4393
4394    for (channel_id, participants) in channels.channel_participants {
4395        update
4396            .channel_participants
4397            .push(proto::ChannelParticipants {
4398                channel_id: channel_id.to_proto(),
4399                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4400            });
4401    }
4402
4403    for channel in channels.invited_channels {
4404        update.channel_invitations.push(channel.to_proto());
4405    }
4406
4407    update
4408}
4409
4410fn build_initial_contacts_update(
4411    contacts: Vec<db::Contact>,
4412    pool: &ConnectionPool,
4413) -> proto::UpdateContacts {
4414    let mut update = proto::UpdateContacts::default();
4415
4416    for contact in contacts {
4417        match contact {
4418            db::Contact::Accepted { user_id, busy } => {
4419                update.contacts.push(contact_for_user(user_id, busy, pool));
4420            }
4421            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4422            db::Contact::Incoming { user_id } => {
4423                update
4424                    .incoming_requests
4425                    .push(proto::IncomingContactRequest {
4426                        requester_id: user_id.to_proto(),
4427                    })
4428            }
4429        }
4430    }
4431
4432    update
4433}
4434
4435fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4436    proto::Contact {
4437        user_id: user_id.to_proto(),
4438        online: pool.is_user_online(user_id),
4439        busy,
4440    }
4441}
4442
4443fn room_updated(room: &proto::Room, peer: &Peer) {
4444    broadcast(
4445        None,
4446        room.participants
4447            .iter()
4448            .filter_map(|participant| Some(participant.peer_id?.into())),
4449        |peer_id| {
4450            peer.send(
4451                peer_id,
4452                proto::RoomUpdated {
4453                    room: Some(room.clone()),
4454                },
4455            )
4456        },
4457    );
4458}
4459
4460fn channel_updated(
4461    channel: &db::channel::Model,
4462    room: &proto::Room,
4463    peer: &Peer,
4464    pool: &ConnectionPool,
4465) {
4466    let participants = room
4467        .participants
4468        .iter()
4469        .map(|p| p.user_id)
4470        .collect::<Vec<_>>();
4471
4472    broadcast(
4473        None,
4474        pool.channel_connection_ids(channel.root_id())
4475            .filter_map(|(channel_id, role)| {
4476                role.can_see_channel(channel.visibility)
4477                    .then_some(channel_id)
4478            }),
4479        |peer_id| {
4480            peer.send(
4481                peer_id,
4482                proto::UpdateChannels {
4483                    channel_participants: vec![proto::ChannelParticipants {
4484                        channel_id: channel.id.to_proto(),
4485                        participant_user_ids: participants.clone(),
4486                    }],
4487                    ..Default::default()
4488                },
4489            )
4490        },
4491    );
4492}
4493
4494async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4495    let db = session.db().await;
4496
4497    let contacts = db.get_contacts(user_id).await?;
4498    let busy = db.is_user_busy(user_id).await?;
4499
4500    let pool = session.connection_pool().await;
4501    let updated_contact = contact_for_user(user_id, busy, &pool);
4502    for contact in contacts {
4503        if let db::Contact::Accepted {
4504            user_id: contact_user_id,
4505            ..
4506        } = contact
4507        {
4508            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4509                session
4510                    .peer
4511                    .send(
4512                        contact_conn_id,
4513                        proto::UpdateContacts {
4514                            contacts: vec![updated_contact.clone()],
4515                            remove_contacts: Default::default(),
4516                            incoming_requests: Default::default(),
4517                            remove_incoming_requests: Default::default(),
4518                            outgoing_requests: Default::default(),
4519                            remove_outgoing_requests: Default::default(),
4520                        },
4521                    )
4522                    .trace_err();
4523            }
4524        }
4525    }
4526    Ok(())
4527}
4528
4529async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4530    let mut contacts_to_update = HashSet::default();
4531
4532    let room_id;
4533    let canceled_calls_to_user_ids;
4534    let livekit_room;
4535    let delete_livekit_room;
4536    let room;
4537    let channel;
4538
4539    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4540        contacts_to_update.insert(session.user_id());
4541
4542        for project in left_room.left_projects.values() {
4543            project_left(project, session);
4544        }
4545
4546        room_id = RoomId::from_proto(left_room.room.id);
4547        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4548        livekit_room = mem::take(&mut left_room.room.livekit_room);
4549        delete_livekit_room = left_room.deleted;
4550        room = mem::take(&mut left_room.room);
4551        channel = mem::take(&mut left_room.channel);
4552
4553        room_updated(&room, &session.peer);
4554    } else {
4555        return Ok(());
4556    }
4557
4558    if let Some(channel) = channel {
4559        channel_updated(
4560            &channel,
4561            &room,
4562            &session.peer,
4563            &*session.connection_pool().await,
4564        );
4565    }
4566
4567    {
4568        let pool = session.connection_pool().await;
4569        for canceled_user_id in canceled_calls_to_user_ids {
4570            for connection_id in pool.user_connection_ids(canceled_user_id) {
4571                session
4572                    .peer
4573                    .send(
4574                        connection_id,
4575                        proto::CallCanceled {
4576                            room_id: room_id.to_proto(),
4577                        },
4578                    )
4579                    .trace_err();
4580            }
4581            contacts_to_update.insert(canceled_user_id);
4582        }
4583    }
4584
4585    for contact_user_id in contacts_to_update {
4586        update_user_contacts(contact_user_id, session).await?;
4587    }
4588
4589    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4590        live_kit
4591            .remove_participant(livekit_room.clone(), session.user_id().to_string())
4592            .await
4593            .trace_err();
4594
4595        if delete_livekit_room {
4596            live_kit.delete_room(livekit_room).await.trace_err();
4597        }
4598    }
4599
4600    Ok(())
4601}
4602
4603async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4604    let left_channel_buffers = session
4605        .db()
4606        .await
4607        .leave_channel_buffers(session.connection_id)
4608        .await?;
4609
4610    for left_buffer in left_channel_buffers {
4611        channel_buffer_updated(
4612            session.connection_id,
4613            left_buffer.connections,
4614            &proto::UpdateChannelBufferCollaborators {
4615                channel_id: left_buffer.channel_id.to_proto(),
4616                collaborators: left_buffer.collaborators,
4617            },
4618            &session.peer,
4619        );
4620    }
4621
4622    Ok(())
4623}
4624
4625fn project_left(project: &db::LeftProject, session: &Session) {
4626    for connection_id in &project.connection_ids {
4627        if project.should_unshare {
4628            session
4629                .peer
4630                .send(
4631                    *connection_id,
4632                    proto::UnshareProject {
4633                        project_id: project.id.to_proto(),
4634                    },
4635                )
4636                .trace_err();
4637        } else {
4638            session
4639                .peer
4640                .send(
4641                    *connection_id,
4642                    proto::RemoveProjectCollaborator {
4643                        project_id: project.id.to_proto(),
4644                        peer_id: Some(session.connection_id.into()),
4645                    },
4646                )
4647                .trace_err();
4648        }
4649    }
4650}
4651
4652pub trait ResultExt {
4653    type Ok;
4654
4655    fn trace_err(self) -> Option<Self::Ok>;
4656}
4657
4658impl<T, E> ResultExt for Result<T, E>
4659where
4660    E: std::fmt::Debug,
4661{
4662    type Ok = T;
4663
4664    #[track_caller]
4665    fn trace_err(self) -> Option<T> {
4666        match self {
4667            Ok(value) => Some(value),
4668            Err(error) => {
4669                tracing::error!("{:?}", error);
4670                None
4671            }
4672        }
4673    }
4674}