rpc.rs

   1mod connection_pool;
   2
   3use crate::api::billing::find_or_create_billing_customer;
   4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   5use crate::db::billing_subscription::SubscriptionKind;
   6use crate::llm::db::LlmDatabase;
   7use crate::llm::{
   8    AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
   9    MIN_ACCOUNT_AGE_FOR_LLM_USE,
  10};
  11use crate::stripe_client::StripeCustomerId;
  12use crate::{
  13    AppState, Error, Result, auth,
  14    db::{
  15        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
  16        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  17        NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
  18        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  19    },
  20    executor::Executor,
  21};
  22use anyhow::{Context as _, anyhow, bail};
  23use async_tungstenite::tungstenite::{
  24    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  25};
  26use axum::{
  27    Extension, Router, TypedHeader,
  28    body::Body,
  29    extract::{
  30        ConnectInfo, WebSocketUpgrade,
  31        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  32    },
  33    headers::{Header, HeaderName},
  34    http::StatusCode,
  35    middleware,
  36    response::IntoResponse,
  37    routing::get,
  38};
  39use chrono::Utc;
  40use collections::{HashMap, HashSet};
  41pub use connection_pool::{ConnectionPool, ZedVersion};
  42use core::fmt::{self, Debug, Formatter};
  43use reqwest_client::ReqwestClient;
  44use rpc::proto::split_repository_update;
  45use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  46
  47use futures::{
  48    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  49    stream::FuturesUnordered,
  50};
  51use prometheus::{IntGauge, register_int_gauge};
  52use rpc::{
  53    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  54    proto::{
  55        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  56        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  57    },
  58};
  59use semantic_version::SemanticVersion;
  60use serde::{Serialize, Serializer};
  61use std::{
  62    any::TypeId,
  63    future::Future,
  64    marker::PhantomData,
  65    mem,
  66    net::SocketAddr,
  67    ops::{Deref, DerefMut},
  68    rc::Rc,
  69    sync::{
  70        Arc, OnceLock,
  71        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  72    },
  73    time::{Duration, Instant},
  74};
  75use time::OffsetDateTime;
  76use tokio::sync::{Semaphore, watch};
  77use tower::ServiceBuilder;
  78use tracing::{
  79    Instrument,
  80    field::{self},
  81    info_span, instrument,
  82};
  83
  84pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  85
  86// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  87pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  88
  89const MESSAGE_COUNT_PER_PAGE: usize = 100;
  90const MAX_MESSAGE_LEN: usize = 1024;
  91const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  92const MAX_CONCURRENT_CONNECTIONS: usize = 512;
  93
  94static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
  95
  96type MessageHandler =
  97    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  98
  99pub struct ConnectionGuard;
 100
 101impl ConnectionGuard {
 102    pub fn try_acquire() -> Result<Self, ()> {
 103        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
 104        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 105            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 106            tracing::error!(
 107                "too many concurrent connections: {}",
 108                current_connections + 1
 109            );
 110            return Err(());
 111        }
 112        Ok(ConnectionGuard)
 113    }
 114}
 115
 116impl Drop for ConnectionGuard {
 117    fn drop(&mut self) {
 118        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 119    }
 120}
 121
 122struct Response<R> {
 123    peer: Arc<Peer>,
 124    receipt: Receipt<R>,
 125    responded: Arc<AtomicBool>,
 126}
 127
 128impl<R: RequestMessage> Response<R> {
 129    fn send(self, payload: R::Response) -> Result<()> {
 130        self.responded.store(true, SeqCst);
 131        self.peer.respond(self.receipt, payload)?;
 132        Ok(())
 133    }
 134}
 135
 136#[derive(Clone, Debug)]
 137pub enum Principal {
 138    User(User),
 139    Impersonated { user: User, admin: User },
 140}
 141
 142impl Principal {
 143    fn user(&self) -> &User {
 144        match self {
 145            Principal::User(user) => user,
 146            Principal::Impersonated { user, .. } => user,
 147        }
 148    }
 149
 150    fn update_span(&self, span: &tracing::Span) {
 151        match &self {
 152            Principal::User(user) => {
 153                span.record("user_id", user.id.0);
 154                span.record("login", &user.github_login);
 155            }
 156            Principal::Impersonated { user, admin } => {
 157                span.record("user_id", user.id.0);
 158                span.record("login", &user.github_login);
 159                span.record("impersonator", &admin.github_login);
 160            }
 161        }
 162    }
 163}
 164
/// Per-connection state passed by value to every `MessageHandler`.
#[derive(Clone)]
struct Session {
    /// Who this connection acts as; may be an admin impersonating another user.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared pool of live connections; lock it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Optional Supermaven admin API client.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    // NOTE(review): not read anywhere in the visible code; `allow(unused)` silences the
    // dead-code lint — confirm whether the field is still needed.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    _executor: Executor,
}
 180
 181impl Session {
 182    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 183        #[cfg(test)]
 184        tokio::task::yield_now().await;
 185        let guard = self.db.lock().await;
 186        #[cfg(test)]
 187        tokio::task::yield_now().await;
 188        guard
 189    }
 190
 191    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 192        #[cfg(test)]
 193        tokio::task::yield_now().await;
 194        let guard = self.connection_pool.lock();
 195        ConnectionPoolGuard {
 196            guard,
 197            _not_send: PhantomData,
 198        }
 199    }
 200
 201    fn is_staff(&self) -> bool {
 202        match &self.principal {
 203            Principal::User(user) => user.admin,
 204            Principal::Impersonated { .. } => true,
 205        }
 206    }
 207
 208    fn user_id(&self) -> UserId {
 209        match &self.principal {
 210            Principal::User(user) => user.id,
 211            Principal::Impersonated { user, .. } => user.id,
 212        }
 213    }
 214
 215    pub fn email(&self) -> Option<String> {
 216        match &self.principal {
 217            Principal::User(user) => user.email_address.clone(),
 218            Principal::Impersonated { user, .. } => user.email_address.clone(),
 219        }
 220    }
 221}
 222
 223impl Debug for Session {
 224    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 225        let mut result = f.debug_struct("Session");
 226        match &self.principal {
 227            Principal::User(user) => {
 228                result.field("user", &user.github_login);
 229            }
 230            Principal::Impersonated { user, admin } => {
 231                result.field("user", &user.github_login);
 232                result.field("impersonator", &admin.github_login);
 233            }
 234        }
 235        result.field("connection_id", &self.connection_id).finish()
 236    }
 237}
 238
 239struct DbHandle(Arc<Database>);
 240
 241impl Deref for DbHandle {
 242    type Target = Database;
 243
 244    fn deref(&self) -> &Self::Target {
 245        self.0.as_ref()
 246    }
 247}
 248
/// The collab RPC server: owns the peer, the connection pool, and the message-handler table.
pub struct Server {
    /// This server's id; behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Live connections; also locked by the cleanup task spawned in `start`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message they handle (see `add_handler`).
    handlers: HashMap<TypeId, MessageHandler>,
    /// Publishes `true` from `teardown` and `false` again from the test-only `reset`.
    teardown: watch::Sender<bool>,
}
 257
/// Lock guard over the shared `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so it cannot be held across an
/// `.await` inside a future that has to be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 262
/// Serializable view of server state: the peer plus the connection pool.
///
/// The snapshot holds the pool's lock guard, so the pool stays locked for as long as the
/// snapshot lives. serde writes fields in declaration order, so field order here is the
/// key order of the serialized output.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through `Deref`, so the pool itself (not the guard) is written out.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 269
 270pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 271where
 272    S: Serializer,
 273    T: Deref<Target = U>,
 274    U: Serialize,
 275{
 276    Serialize::serialize(value.deref(), serializer)
 277}
 278
 279impl Server {
 280    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 281        let mut server = Self {
 282            id: parking_lot::Mutex::new(id),
 283            peer: Peer::new(id.0 as u32),
 284            app_state: app_state.clone(),
 285            connection_pool: Default::default(),
 286            handlers: Default::default(),
 287            teardown: watch::channel(false).0,
 288        };
 289
 290        server
 291            .add_request_handler(ping)
 292            .add_request_handler(create_room)
 293            .add_request_handler(join_room)
 294            .add_request_handler(rejoin_room)
 295            .add_request_handler(leave_room)
 296            .add_request_handler(set_room_participant_role)
 297            .add_request_handler(call)
 298            .add_request_handler(cancel_call)
 299            .add_message_handler(decline_call)
 300            .add_request_handler(update_participant_location)
 301            .add_request_handler(share_project)
 302            .add_message_handler(unshare_project)
 303            .add_request_handler(join_project)
 304            .add_message_handler(leave_project)
 305            .add_request_handler(update_project)
 306            .add_request_handler(update_worktree)
 307            .add_request_handler(update_repository)
 308            .add_request_handler(remove_repository)
 309            .add_message_handler(start_language_server)
 310            .add_message_handler(update_language_server)
 311            .add_message_handler(update_diagnostic_summary)
 312            .add_message_handler(update_worktree_settings)
 313            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 314            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 315            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 316            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 317            .add_request_handler(forward_find_search_candidates_request)
 318            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 319            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 320            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 321            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 322            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 323            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 324            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 325            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 326            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 327            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 328            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 329            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 330            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 331            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 332            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 333            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 334            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 335            .add_request_handler(
 336                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 337            )
 338            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 339            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 340            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 341            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 342            .add_request_handler(
 343                forward_read_only_project_request::<proto::LanguageServerIdForName>,
 344            )
 345            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
 346            .add_request_handler(
 347                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 348            )
 349            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 350            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 351            .add_request_handler(
 352                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 353            )
 354            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 355            .add_request_handler(
 356                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 357            )
 358            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 359            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 360            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 361            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 362            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 363            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 364            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 365            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 366            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 367            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 368            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 369            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 370            .add_request_handler(
 371                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 372            )
 373            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 374            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 375            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 376            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 377            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 378            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 379            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 380            .add_message_handler(create_buffer_for_peer)
 381            .add_request_handler(update_buffer)
 382            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 383            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 384            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 385            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 386            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 387            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 388            .add_message_handler(
 389                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 390            )
 391            .add_request_handler(get_users)
 392            .add_request_handler(fuzzy_search_users)
 393            .add_request_handler(request_contact)
 394            .add_request_handler(remove_contact)
 395            .add_request_handler(respond_to_contact_request)
 396            .add_message_handler(subscribe_to_channels)
 397            .add_request_handler(create_channel)
 398            .add_request_handler(delete_channel)
 399            .add_request_handler(invite_channel_member)
 400            .add_request_handler(remove_channel_member)
 401            .add_request_handler(set_channel_member_role)
 402            .add_request_handler(set_channel_visibility)
 403            .add_request_handler(rename_channel)
 404            .add_request_handler(join_channel_buffer)
 405            .add_request_handler(leave_channel_buffer)
 406            .add_message_handler(update_channel_buffer)
 407            .add_request_handler(rejoin_channel_buffers)
 408            .add_request_handler(get_channel_members)
 409            .add_request_handler(respond_to_channel_invite)
 410            .add_request_handler(join_channel)
 411            .add_request_handler(join_channel_chat)
 412            .add_message_handler(leave_channel_chat)
 413            .add_request_handler(send_channel_message)
 414            .add_request_handler(remove_channel_message)
 415            .add_request_handler(update_channel_message)
 416            .add_request_handler(get_channel_messages)
 417            .add_request_handler(get_channel_messages_by_id)
 418            .add_request_handler(get_notifications)
 419            .add_request_handler(mark_notification_as_read)
 420            .add_request_handler(move_channel)
 421            .add_request_handler(reorder_channel)
 422            .add_request_handler(follow)
 423            .add_message_handler(unfollow)
 424            .add_message_handler(update_followers)
 425            .add_request_handler(get_private_user_info)
 426            .add_request_handler(get_llm_api_token)
 427            .add_request_handler(accept_terms_of_service)
 428            .add_message_handler(acknowledge_channel_message)
 429            .add_message_handler(acknowledge_buffer_version)
 430            .add_request_handler(get_supermaven_api_key)
 431            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 432            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 433            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 434            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 435            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 436            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 437            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 438            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 439            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 440            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 441            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 442            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 443            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 444            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 445            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 446            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 447            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 448            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 449            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 450            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 451            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 452            .add_message_handler(update_context);
 453
 454        Arc::new(server)
 455    }
 456
    /// Spawns the one-off post-startup cleanup task and returns immediately.
    ///
    /// After `CLEANUP_TIMEOUT`, the detached task works on this `server_id` and the
    /// configured `zed_environment`:
    /// 1. deletes stale channel-chat participants;
    /// 2. fetches the ids of stale rooms and channel buffers, then
    ///    - clears stale channel-buffer collaborators and sends the refreshed collaborator
    ///      list to the buffer's connections,
    ///    - clears stale room participants, calls `room_updated` / `channel_updated`, sends
    ///      `CallCanceled` to users whose calls were canceled, pushes an updated contact
    ///      entry to the accepted contacts of every affected user, and, if a LiveKit client
    ///      is configured, deletes the LiveKit room once it has no participants;
    /// 3. deletes stale channel-chat participants again, clears old worktree entries, and
    ///    deletes stale servers.
    ///
    /// Failures inside the task are logged via `trace_err` and skipped; this method itself
    /// always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        // Detached: `start` does not wait for the cleanup to finish.
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and send the refreshed
                    // collaborator list to the buffer's connections.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Populated only if clearing the room succeeds; otherwise the steps
                        // below run on these empty defaults and do nothing.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scope the pool lock to the synchronous `CallCanceled` sends.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Both lookups are awaited before the pool is locked, so the
                            // synchronous lock is never held across an `.await`.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room once nobody is left in it.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // NOTE(review): identical to the call at the top of this task; this second
                // pass looks redundant — confirm whether it is intentional before removing it.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 627
 628    pub fn teardown(&self) {
 629        self.peer.teardown();
 630        self.connection_pool.lock().reset();
 631        let _ = self.teardown.send(true);
 632    }
 633
 634    #[cfg(test)]
 635    pub fn reset(&self, id: ServerId) {
 636        self.teardown();
 637        *self.id.lock() = id;
 638        self.peer.reset(id.0 as u32);
 639        let _ = self.teardown.send(false);
 640    }
 641
 642    #[cfg(test)]
 643    pub fn id(&self) -> ServerId {
 644        *self.id.lock()
 645    }
 646
 647    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 648    where
 649        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 650        Fut: 'static + Send + Future<Output = Result<()>>,
 651        M: EnvelopedMessage,
 652    {
 653        let prev_handler = self.handlers.insert(
 654            TypeId::of::<M>(),
 655            Box::new(move |envelope, session| {
 656                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 657                let received_at = envelope.received_at;
 658                tracing::info!("message received");
 659                let start_time = Instant::now();
 660                let future = (handler)(*envelope, session);
 661                async move {
 662                    let result = future.await;
 663                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 664                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 665                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 666                    let payload_type = M::NAME;
 667
 668                    match result {
 669                        Err(error) => {
 670                            tracing::error!(
 671                                ?error,
 672                                total_duration_ms,
 673                                processing_duration_ms,
 674                                queue_duration_ms,
 675                                payload_type,
 676                                "error handling message"
 677                            )
 678                        }
 679                        Ok(()) => tracing::info!(
 680                            total_duration_ms,
 681                            processing_duration_ms,
 682                            queue_duration_ms,
 683                            "finished handling message"
 684                        ),
 685                    }
 686                }
 687                .boxed()
 688            }),
 689        );
 690        if prev_handler.is_some() {
 691            panic!("registered a handler for the same message twice");
 692        }
 693        self
 694    }
 695
 696    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 697    where
 698        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 699        Fut: 'static + Send + Future<Output = Result<()>>,
 700        M: EnvelopedMessage,
 701    {
 702        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 703        self
 704    }
 705
 706    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 707    where
 708        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 709        Fut: Send + Future<Output = Result<()>>,
 710        M: RequestMessage,
 711    {
 712        let handler = Arc::new(handler);
 713        self.add_handler(move |envelope, session| {
 714            let receipt = envelope.receipt();
 715            let handler = handler.clone();
 716            async move {
 717                let peer = session.peer.clone();
 718                let responded = Arc::new(AtomicBool::default());
 719                let response = Response {
 720                    peer: peer.clone(),
 721                    responded: responded.clone(),
 722                    receipt,
 723                };
 724                match (handler)(envelope.payload, response, session).await {
 725                    Ok(()) => {
 726                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 727                            Ok(())
 728                        } else {
 729                            Err(anyhow!("handler did not send a response"))?
 730                        }
 731                    }
 732                    Err(error) => {
 733                        let proto_err = match &error {
 734                            Error::Internal(err) => err.to_proto(),
 735                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 736                        };
 737                        peer.respond_with_error(receipt, proto_err)?;
 738                        Err(error)
 739                    }
 740                }
 741            }
 742        })
 743    }
 744
 745    pub fn handle_connection(
 746        self: &Arc<Self>,
 747        connection: Connection,
 748        address: String,
 749        principal: Principal,
 750        zed_version: ZedVersion,
 751        geoip_country_code: Option<String>,
 752        system_id: Option<String>,
 753        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 754        executor: Executor,
 755        connection_guard: Option<ConnectionGuard>,
 756    ) -> impl Future<Output = ()> + use<> {
 757        let this = self.clone();
 758        let span = info_span!("handle connection", %address,
 759            connection_id=field::Empty,
 760            user_id=field::Empty,
 761            login=field::Empty,
 762            impersonator=field::Empty,
 763            geoip_country_code=field::Empty
 764        );
 765        principal.update_span(&span);
 766        if let Some(country_code) = geoip_country_code.as_ref() {
 767            span.record("geoip_country_code", country_code);
 768        }
 769
 770        let mut teardown = self.teardown.subscribe();
 771        async move {
 772            if *teardown.borrow() {
 773                tracing::error!("server is tearing down");
 774                return
 775            }
 776
 777            let (connection_id, handle_io, mut incoming_rx) = this
 778                .peer
 779                .add_connection(connection, {
 780                    let executor = executor.clone();
 781                    move |duration| executor.sleep(duration)
 782                });
 783            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 784
 785            tracing::info!("connection opened");
 786
 787            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
 788            let http_client = match ReqwestClient::user_agent(&user_agent) {
 789                Ok(http_client) => Arc::new(http_client),
 790                Err(error) => {
 791                    tracing::error!(?error, "failed to create HTTP client");
 792                    return;
 793                }
 794            };
 795
 796            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
 797                    supermaven_admin_api_key.to_string(),
 798                    http_client.clone(),
 799                )));
 800
 801            let session = Session {
 802                principal: principal.clone(),
 803                connection_id,
 804                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 805                peer: this.peer.clone(),
 806                connection_pool: this.connection_pool.clone(),
 807                app_state: this.app_state.clone(),
 808                geoip_country_code,
 809                system_id,
 810                _executor: executor.clone(),
 811                supermaven_client,
 812            };
 813
 814            if let Err(error) = this.send_initial_client_update(connection_id, zed_version, send_connection_id, &session).await {
 815                tracing::error!(?error, "failed to send initial client update");
 816                return;
 817            }
 818            drop(connection_guard);
 819
 820            let handle_io = handle_io.fuse();
 821            futures::pin_mut!(handle_io);
 822
 823            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 824            // This prevents deadlocks when e.g., client A performs a request to client B and
 825            // client B performs a request to client A. If both clients stop processing further
 826            // messages until their respective request completes, they won't have a chance to
 827            // respond to the other client's request and cause a deadlock.
 828            //
 829            // This arrangement ensures we will attempt to process earlier messages first, but fall
 830            // back to processing messages arrived later in the spirit of making progress.
 831            let mut foreground_message_handlers = FuturesUnordered::new();
 832            let concurrent_handlers = Arc::new(Semaphore::new(512));
 833            loop {
 834                let next_message = async {
 835                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 836                    let message = incoming_rx.next().await;
 837                    (permit, message)
 838                }.fuse();
 839                futures::pin_mut!(next_message);
 840                futures::select_biased! {
 841                    _ = teardown.changed().fuse() => return,
 842                    result = handle_io => {
 843                        if let Err(error) = result {
 844                            tracing::error!(?error, "error handling I/O");
 845                        }
 846                        break;
 847                    }
 848                    _ = foreground_message_handlers.next() => {}
 849                    next_message = next_message => {
 850                        let (permit, message) = next_message;
 851                        if let Some(message) = message {
 852                            let type_name = message.payload_type_name();
 853                            // note: we copy all the fields from the parent span so we can query them in the logs.
 854                            // (https://github.com/tokio-rs/tracing/issues/2670).
 855                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
 856                                user_id=field::Empty,
 857                                login=field::Empty,
 858                                impersonator=field::Empty,
 859                            );
 860                            principal.update_span(&span);
 861                            let span_enter = span.enter();
 862                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 863                                let is_background = message.is_background();
 864                                let handle_message = (handler)(message, session.clone());
 865                                drop(span_enter);
 866
 867                                let handle_message = async move {
 868                                    handle_message.await;
 869                                    drop(permit);
 870                                }.instrument(span);
 871                                if is_background {
 872                                    executor.spawn_detached(handle_message);
 873                                } else {
 874                                    foreground_message_handlers.push(handle_message);
 875                                }
 876                            } else {
 877                                tracing::error!("no message handler");
 878                            }
 879                        } else {
 880                            tracing::info!("connection closed");
 881                            break;
 882                        }
 883                    }
 884                }
 885            }
 886
 887            drop(foreground_message_handlers);
 888            tracing::info!("signing out");
 889            if let Err(error) = connection_lost(session, teardown, executor).await {
 890                tracing::error!(?error, "error signing out");
 891            }
 892
 893        }.instrument(span)
 894    }
 895
 896    async fn send_initial_client_update(
 897        &self,
 898        connection_id: ConnectionId,
 899        zed_version: ZedVersion,
 900        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 901        session: &Session,
 902    ) -> Result<()> {
 903        self.peer.send(
 904            connection_id,
 905            proto::Hello {
 906                peer_id: Some(connection_id.into()),
 907            },
 908        )?;
 909        tracing::info!("sent hello message");
 910        if let Some(send_connection_id) = send_connection_id.take() {
 911            let _ = send_connection_id.send(connection_id);
 912        }
 913
 914        match &session.principal {
 915            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 916                if !user.connected_once {
 917                    self.peer.send(connection_id, proto::ShowContacts {})?;
 918                    self.app_state
 919                        .db
 920                        .set_user_connected_once(user.id, true)
 921                        .await?;
 922                }
 923
 924                update_user_plan(session).await?;
 925
 926                let contacts = self.app_state.db.get_contacts(user.id).await?;
 927
 928                {
 929                    let mut pool = self.connection_pool.lock();
 930                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 931                    self.peer.send(
 932                        connection_id,
 933                        build_initial_contacts_update(contacts, &pool),
 934                    )?;
 935                }
 936
 937                if should_auto_subscribe_to_channels(zed_version) {
 938                    subscribe_user_to_channels(user.id, session).await?;
 939                }
 940
 941                if let Some(incoming_call) =
 942                    self.app_state.db.incoming_call_for_user(user.id).await?
 943                {
 944                    self.peer.send(connection_id, incoming_call)?;
 945                }
 946
 947                update_user_contacts(user.id, session).await?;
 948            }
 949        }
 950
 951        Ok(())
 952    }
 953
 954    pub async fn invite_code_redeemed(
 955        self: &Arc<Self>,
 956        inviter_id: UserId,
 957        invitee_id: UserId,
 958    ) -> Result<()> {
 959        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 960            if let Some(code) = &user.invite_code {
 961                let pool = self.connection_pool.lock();
 962                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 963                for connection_id in pool.user_connection_ids(inviter_id) {
 964                    self.peer.send(
 965                        connection_id,
 966                        proto::UpdateContacts {
 967                            contacts: vec![invitee_contact.clone()],
 968                            ..Default::default()
 969                        },
 970                    )?;
 971                    self.peer.send(
 972                        connection_id,
 973                        proto::UpdateInviteInfo {
 974                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 975                            count: user.invite_count as u32,
 976                        },
 977                    )?;
 978                }
 979            }
 980        }
 981        Ok(())
 982    }
 983
 984    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 985        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 986            if let Some(invite_code) = &user.invite_code {
 987                let pool = self.connection_pool.lock();
 988                for connection_id in pool.user_connection_ids(user_id) {
 989                    self.peer.send(
 990                        connection_id,
 991                        proto::UpdateInviteInfo {
 992                            url: format!(
 993                                "{}{}",
 994                                self.app_state.config.invite_link_prefix, invite_code
 995                            ),
 996                            count: user.invite_count as u32,
 997                        },
 998                    )?;
 999                }
1000            }
1001        }
1002        Ok(())
1003    }
1004
1005    pub async fn update_plan_for_user(
1006        self: &Arc<Self>,
1007        user_id: UserId,
1008        update_user_plan: proto::UpdateUserPlan,
1009    ) -> Result<()> {
1010        let pool = self.connection_pool.lock();
1011        for connection_id in pool.user_connection_ids(user_id) {
1012            self.peer
1013                .send(connection_id, update_user_plan.clone())
1014                .trace_err();
1015        }
1016
1017        Ok(())
1018    }
1019
1020    /// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan`
1021    /// message on the Collab server.
1022    ///
1023    /// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint.
1024    pub async fn update_plan_for_user_legacy(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1025        let user = self
1026            .app_state
1027            .db
1028            .get_user_by_id(user_id)
1029            .await?
1030            .context("user not found")?;
1031
1032        let update_user_plan = make_update_user_plan_message(
1033            &user,
1034            user.admin,
1035            &self.app_state.db,
1036            self.app_state.llm_db.clone(),
1037        )
1038        .await?;
1039
1040        self.update_plan_for_user(user_id, update_user_plan).await
1041    }
1042
1043    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
1044        let pool = self.connection_pool.lock();
1045        for connection_id in pool.user_connection_ids(user_id) {
1046            self.peer
1047                .send(connection_id, proto::RefreshLlmToken {})
1048                .trace_err();
1049        }
1050    }
1051
1052    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1053        ServerSnapshot {
1054            connection_pool: ConnectionPoolGuard {
1055                guard: self.connection_pool.lock(),
1056                _not_send: PhantomData,
1057            },
1058            peer: &self.peer,
1059        }
1060    }
1061}
1062
1063impl Deref for ConnectionPoolGuard<'_> {
1064    type Target = ConnectionPool;
1065
1066    fn deref(&self) -> &Self::Target {
1067        &self.guard
1068    }
1069}
1070
1071impl DerefMut for ConnectionPoolGuard<'_> {
1072    fn deref_mut(&mut self) -> &mut Self::Target {
1073        &mut self.guard
1074    }
1075}
1076
1077impl Drop for ConnectionPoolGuard<'_> {
1078    fn drop(&mut self) {
1079        #[cfg(test)]
1080        self.check_invariants();
1081    }
1082}
1083
1084fn broadcast<F>(
1085    sender_id: Option<ConnectionId>,
1086    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1087    mut f: F,
1088) where
1089    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1090{
1091    for receiver_id in receiver_ids {
1092        if Some(receiver_id) != sender_id {
1093            if let Err(error) = f(receiver_id) {
1094                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1095            }
1096        }
1097    }
1098}
1099
1100pub struct ProtocolVersion(u32);
1101
1102impl Header for ProtocolVersion {
1103    fn name() -> &'static HeaderName {
1104        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1105        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1106    }
1107
1108    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1109    where
1110        Self: Sized,
1111        I: Iterator<Item = &'i axum::http::HeaderValue>,
1112    {
1113        let version = values
1114            .next()
1115            .ok_or_else(axum::headers::Error::invalid)?
1116            .to_str()
1117            .map_err(|_| axum::headers::Error::invalid())?
1118            .parse()
1119            .map_err(|_| axum::headers::Error::invalid())?;
1120        Ok(Self(version))
1121    }
1122
1123    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1124        values.extend([self.0.to_string().parse().unwrap()]);
1125    }
1126}
1127
1128pub struct AppVersionHeader(SemanticVersion);
1129impl Header for AppVersionHeader {
1130    fn name() -> &'static HeaderName {
1131        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1132        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1133    }
1134
1135    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1136    where
1137        Self: Sized,
1138        I: Iterator<Item = &'i axum::http::HeaderValue>,
1139    {
1140        let version = values
1141            .next()
1142            .ok_or_else(axum::headers::Error::invalid)?
1143            .to_str()
1144            .map_err(|_| axum::headers::Error::invalid())?
1145            .parse()
1146            .map_err(|_| axum::headers::Error::invalid())?;
1147        Ok(Self(version))
1148    }
1149
1150    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1151        values.extend([self.0.to_string().parse().unwrap()]);
1152    }
1153}
1154
1155pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1156    Router::new()
1157        .route("/rpc", get(handle_websocket_request))
1158        .layer(
1159            ServiceBuilder::new()
1160                .layer(Extension(server.app_state.clone()))
1161                .layer(middleware::from_fn(auth::validate_header)),
1162        )
1163        .route("/metrics", get(handle_metrics))
1164        .layer(Extension(server))
1165}
1166
1167pub async fn handle_websocket_request(
1168    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1169    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1170    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1171    Extension(server): Extension<Arc<Server>>,
1172    Extension(principal): Extension<Principal>,
1173    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1174    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1175    ws: WebSocketUpgrade,
1176) -> axum::response::Response {
1177    if protocol_version != rpc::PROTOCOL_VERSION {
1178        return (
1179            StatusCode::UPGRADE_REQUIRED,
1180            "client must be upgraded".to_string(),
1181        )
1182            .into_response();
1183    }
1184
1185    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1186        return (
1187            StatusCode::UPGRADE_REQUIRED,
1188            "no version header found".to_string(),
1189        )
1190            .into_response();
1191    };
1192
1193    if !version.can_collaborate() {
1194        return (
1195            StatusCode::UPGRADE_REQUIRED,
1196            "client must be upgraded".to_string(),
1197        )
1198            .into_response();
1199    }
1200
1201    let socket_address = socket_address.to_string();
1202
1203    // Acquire connection guard before WebSocket upgrade
1204    let connection_guard = match ConnectionGuard::try_acquire() {
1205        Ok(guard) => guard,
1206        Err(()) => {
1207            return (
1208                StatusCode::SERVICE_UNAVAILABLE,
1209                "Too many concurrent connections",
1210            )
1211                .into_response();
1212        }
1213    };
1214
1215    ws.on_upgrade(move |socket| {
1216        let socket = socket
1217            .map_ok(to_tungstenite_message)
1218            .err_into()
1219            .with(|message| async move { to_axum_message(message) });
1220        let connection = Connection::new(Box::pin(socket));
1221        async move {
1222            server
1223                .handle_connection(
1224                    connection,
1225                    socket_address,
1226                    principal,
1227                    version,
1228                    country_code_header.map(|header| header.to_string()),
1229                    system_id_header.map(|header| header.to_string()),
1230                    None,
1231                    Executor::Production,
1232                    Some(connection_guard),
1233                )
1234                .await;
1235        }
1236    })
1237}
1238
1239pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1240    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1241    let connections_metric = CONNECTIONS_METRIC
1242        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1243
1244    let connections = server
1245        .connection_pool
1246        .lock()
1247        .connections()
1248        .filter(|connection| !connection.admin)
1249        .count();
1250    connections_metric.set(connections as _);
1251
1252    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1253    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1254        register_int_gauge!(
1255            "shared_projects",
1256            "number of open projects with one or more guests"
1257        )
1258        .unwrap()
1259    });
1260
1261    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1262    shared_projects_metric.set(shared_projects as _);
1263
1264    let encoder = prometheus::TextEncoder::new();
1265    let metric_families = prometheus::gather();
1266    let encoded_metrics = encoder
1267        .encode_to_string(&metric_families)
1268        .map_err(|err| anyhow!("{err}"))?;
1269    Ok(encoded_metrics)
1270}
1271
1272#[instrument(err, skip(executor))]
1273async fn connection_lost(
1274    session: Session,
1275    mut teardown: watch::Receiver<bool>,
1276    executor: Executor,
1277) -> Result<()> {
1278    session.peer.disconnect(session.connection_id);
1279    session
1280        .connection_pool()
1281        .await
1282        .remove_connection(session.connection_id)?;
1283
1284    session
1285        .db()
1286        .await
1287        .connection_lost(session.connection_id)
1288        .await
1289        .trace_err();
1290
1291    futures::select_biased! {
1292        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1293
1294            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1295            leave_room_for_session(&session, session.connection_id).await.trace_err();
1296            leave_channel_buffers_for_session(&session)
1297                .await
1298                .trace_err();
1299
1300            if !session
1301                .connection_pool()
1302                .await
1303                .is_user_online(session.user_id())
1304            {
1305                let db = session.db().await;
1306                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1307                    room_updated(&room, &session.peer);
1308                }
1309            }
1310
1311            update_user_contacts(session.user_id(), &session).await?;
1312        },
1313        _ = teardown.changed().fuse() => {}
1314    }
1315
1316    Ok(())
1317}
1318
1319/// Acknowledges a ping from a client, used to keep the connection alive.
1320async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1321    response.send(proto::Ack {})?;
1322    Ok(())
1323}
1324
1325/// Creates a new room for calling (outside of channels)
1326async fn create_room(
1327    _request: proto::CreateRoom,
1328    response: Response<proto::CreateRoom>,
1329    session: Session,
1330) -> Result<()> {
1331    let livekit_room = nanoid::nanoid!(30);
1332
1333    let live_kit_connection_info = util::maybe!(async {
1334        let live_kit = session.app_state.livekit_client.as_ref();
1335        let live_kit = live_kit?;
1336        let user_id = session.user_id().to_string();
1337
1338        let token = live_kit
1339            .room_token(&livekit_room, &user_id.to_string())
1340            .trace_err()?;
1341
1342        Some(proto::LiveKitConnectionInfo {
1343            server_url: live_kit.url().into(),
1344            token,
1345            can_publish: true,
1346        })
1347    })
1348    .await;
1349
1350    let room = session
1351        .db()
1352        .await
1353        .create_room(session.user_id(), session.connection_id, &livekit_room)
1354        .await?;
1355
1356    response.send(proto::CreateRoomResponse {
1357        room: Some(room.clone()),
1358        live_kit_connection_info,
1359    })?;
1360
1361    update_user_contacts(session.user_id(), &session).await?;
1362    Ok(())
1363}
1364
1365/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1366async fn join_room(
1367    request: proto::JoinRoom,
1368    response: Response<proto::JoinRoom>,
1369    session: Session,
1370) -> Result<()> {
1371    let room_id = RoomId::from_proto(request.id);
1372
1373    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1374
1375    if let Some(channel_id) = channel_id {
1376        return join_channel_internal(channel_id, Box::new(response), session).await;
1377    }
1378
1379    let joined_room = {
1380        let room = session
1381            .db()
1382            .await
1383            .join_room(room_id, session.user_id(), session.connection_id)
1384            .await?;
1385        room_updated(&room.room, &session.peer);
1386        room.into_inner()
1387    };
1388
1389    for connection_id in session
1390        .connection_pool()
1391        .await
1392        .user_connection_ids(session.user_id())
1393    {
1394        session
1395            .peer
1396            .send(
1397                connection_id,
1398                proto::CallCanceled {
1399                    room_id: room_id.to_proto(),
1400                },
1401            )
1402            .trace_err();
1403    }
1404
1405    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1406    {
1407        live_kit
1408            .room_token(
1409                &joined_room.room.livekit_room,
1410                &session.user_id().to_string(),
1411            )
1412            .trace_err()
1413            .map(|token| proto::LiveKitConnectionInfo {
1414                server_url: live_kit.url().into(),
1415                token,
1416                can_publish: true,
1417            })
1418    } else {
1419        None
1420    };
1421
1422    response.send(proto::JoinRoomResponse {
1423        room: Some(joined_room.room),
1424        channel_id: None,
1425        live_kit_connection_info,
1426    })?;
1427
1428    update_user_contacts(session.user_id(), &session).await?;
1429    Ok(())
1430}
1431
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Asks the database to rejoin the room for this user on the session's current
/// connection, replies with the room plus the reshared and rejoined projects,
/// broadcasts the room, and tells collaborators of each reshared project about
/// this connection's new peer id. Rejoined projects are handled by
/// `notify_rejoined_projects`. Afterwards the channel (if the room belongs to
/// one) and the user's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Declared here so they can be used after the inner scope (and the value
    // returned by `db.rejoin_room`) has ended.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Reshared projects (presumably ones this connection hosts — TODO confirm):
        // remap the old connection id to the new one for every collaborator.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                // Send errors go to `trace_err` instead of aborting the rejoin.
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktree list, forwarded as coming from
            // this (rejoining) connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Mutable borrow: the helper moves worktree/repository data out of
        // each rejoined project while streaming it.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Unwrap the database result so its room and channel can outlive this scope.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1523
1524fn notify_rejoined_projects(
1525    rejoined_projects: &mut Vec<RejoinedProject>,
1526    session: &Session,
1527) -> Result<()> {
1528    for project in rejoined_projects.iter() {
1529        for collaborator in &project.collaborators {
1530            session
1531                .peer
1532                .send(
1533                    collaborator.connection_id,
1534                    proto::UpdateProjectCollaborator {
1535                        project_id: project.id.to_proto(),
1536                        old_peer_id: Some(project.old_connection_id.into()),
1537                        new_peer_id: Some(session.connection_id.into()),
1538                    },
1539                )
1540                .trace_err();
1541        }
1542    }
1543
1544    for project in rejoined_projects {
1545        for worktree in mem::take(&mut project.worktrees) {
1546            // Stream this worktree's entries.
1547            let message = proto::UpdateWorktree {
1548                project_id: project.id.to_proto(),
1549                worktree_id: worktree.id,
1550                abs_path: worktree.abs_path.clone(),
1551                root_name: worktree.root_name,
1552                updated_entries: worktree.updated_entries,
1553                removed_entries: worktree.removed_entries,
1554                scan_id: worktree.scan_id,
1555                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1556                updated_repositories: worktree.updated_repositories,
1557                removed_repositories: worktree.removed_repositories,
1558            };
1559            for update in proto::split_worktree_update(message) {
1560                session.peer.send(session.connection_id, update)?;
1561            }
1562
1563            // Stream this worktree's diagnostics.
1564            for summary in worktree.diagnostic_summaries {
1565                session.peer.send(
1566                    session.connection_id,
1567                    proto::UpdateDiagnosticSummary {
1568                        project_id: project.id.to_proto(),
1569                        worktree_id: worktree.id,
1570                        summary: Some(summary),
1571                    },
1572                )?;
1573            }
1574
1575            for settings_file in worktree.settings_files {
1576                session.peer.send(
1577                    session.connection_id,
1578                    proto::UpdateWorktreeSettings {
1579                        project_id: project.id.to_proto(),
1580                        worktree_id: worktree.id,
1581                        path: settings_file.path,
1582                        content: Some(settings_file.content),
1583                        kind: Some(settings_file.kind.to_proto().into()),
1584                    },
1585                )?;
1586            }
1587        }
1588
1589        for repository in mem::take(&mut project.updated_repositories) {
1590            for update in split_repository_update(repository) {
1591                session.peer.send(session.connection_id, update)?;
1592            }
1593        }
1594
1595        for id in mem::take(&mut project.removed_repositories) {
1596            session.peer.send(
1597                session.connection_id,
1598                proto::RemoveRepository {
1599                    project_id: project.id.to_proto(),
1600                    id,
1601                },
1602            )?;
1603        }
1604    }
1605
1606    Ok(())
1607}
1608
1609/// leave room disconnects from the room.
1610async fn leave_room(
1611    _: proto::LeaveRoom,
1612    response: Response<proto::LeaveRoom>,
1613    session: Session,
1614) -> Result<()> {
1615    leave_room_for_session(&session, session.connection_id).await?;
1616    response.send(proto::Ack {})?;
1617    Ok(())
1618}
1619
1620/// Updates the permissions of someone else in the room.
1621async fn set_room_participant_role(
1622    request: proto::SetRoomParticipantRole,
1623    response: Response<proto::SetRoomParticipantRole>,
1624    session: Session,
1625) -> Result<()> {
1626    let user_id = UserId::from_proto(request.user_id);
1627    let role = ChannelRole::from(request.role());
1628
1629    let (livekit_room, can_publish) = {
1630        let room = session
1631            .db()
1632            .await
1633            .set_room_participant_role(
1634                session.user_id(),
1635                RoomId::from_proto(request.room_id),
1636                user_id,
1637                role,
1638            )
1639            .await?;
1640
1641        let livekit_room = room.livekit_room.clone();
1642        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1643        room_updated(&room, &session.peer);
1644        (livekit_room, can_publish)
1645    };
1646
1647    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1648        live_kit
1649            .update_participant(
1650                livekit_room.clone(),
1651                request.user_id.to_string(),
1652                livekit_api::proto::ParticipantPermission {
1653                    can_subscribe: true,
1654                    can_publish,
1655                    can_publish_data: can_publish,
1656                    hidden: false,
1657                    recorder: false,
1658                },
1659            )
1660            .await
1661            .trace_err();
1662    }
1663
1664    response.send(proto::Ack {})?;
1665    Ok(())
1666}
1667
1668/// Call someone else into the current room
1669async fn call(
1670    request: proto::Call,
1671    response: Response<proto::Call>,
1672    session: Session,
1673) -> Result<()> {
1674    let room_id = RoomId::from_proto(request.room_id);
1675    let calling_user_id = session.user_id();
1676    let calling_connection_id = session.connection_id;
1677    let called_user_id = UserId::from_proto(request.called_user_id);
1678    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1679    if !session
1680        .db()
1681        .await
1682        .has_contact(calling_user_id, called_user_id)
1683        .await?
1684    {
1685        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1686    }
1687
1688    let incoming_call = {
1689        let (room, incoming_call) = &mut *session
1690            .db()
1691            .await
1692            .call(
1693                room_id,
1694                calling_user_id,
1695                calling_connection_id,
1696                called_user_id,
1697                initial_project_id,
1698            )
1699            .await?;
1700        room_updated(room, &session.peer);
1701        mem::take(incoming_call)
1702    };
1703    update_user_contacts(called_user_id, &session).await?;
1704
1705    let mut calls = session
1706        .connection_pool()
1707        .await
1708        .user_connection_ids(called_user_id)
1709        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1710        .collect::<FuturesUnordered<_>>();
1711
1712    while let Some(call_response) = calls.next().await {
1713        match call_response.as_ref() {
1714            Ok(_) => {
1715                response.send(proto::Ack {})?;
1716                return Ok(());
1717            }
1718            Err(_) => {
1719                call_response.trace_err();
1720            }
1721        }
1722    }
1723
1724    {
1725        let room = session
1726            .db()
1727            .await
1728            .call_failed(room_id, called_user_id)
1729            .await?;
1730        room_updated(&room, &session.peer);
1731    }
1732    update_user_contacts(called_user_id, &session).await?;
1733
1734    Err(anyhow!("failed to ring user"))?
1735}
1736
1737/// Cancel an outgoing call.
1738async fn cancel_call(
1739    request: proto::CancelCall,
1740    response: Response<proto::CancelCall>,
1741    session: Session,
1742) -> Result<()> {
1743    let called_user_id = UserId::from_proto(request.called_user_id);
1744    let room_id = RoomId::from_proto(request.room_id);
1745    {
1746        let room = session
1747            .db()
1748            .await
1749            .cancel_call(room_id, session.connection_id, called_user_id)
1750            .await?;
1751        room_updated(&room, &session.peer);
1752    }
1753
1754    for connection_id in session
1755        .connection_pool()
1756        .await
1757        .user_connection_ids(called_user_id)
1758    {
1759        session
1760            .peer
1761            .send(
1762                connection_id,
1763                proto::CallCanceled {
1764                    room_id: room_id.to_proto(),
1765                },
1766            )
1767            .trace_err();
1768    }
1769    response.send(proto::Ack {})?;
1770
1771    update_user_contacts(called_user_id, &session).await?;
1772    Ok(())
1773}
1774
1775/// Decline an incoming call.
1776async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1777    let room_id = RoomId::from_proto(message.room_id);
1778    {
1779        let room = session
1780            .db()
1781            .await
1782            .decline_call(Some(room_id), session.user_id())
1783            .await?
1784            .context("declining call")?;
1785        room_updated(&room, &session.peer);
1786    }
1787
1788    for connection_id in session
1789        .connection_pool()
1790        .await
1791        .user_connection_ids(session.user_id())
1792    {
1793        session
1794            .peer
1795            .send(
1796                connection_id,
1797                proto::CallCanceled {
1798                    room_id: room_id.to_proto(),
1799                },
1800            )
1801            .trace_err();
1802    }
1803    update_user_contacts(session.user_id(), &session).await?;
1804    Ok(())
1805}
1806
1807/// Updates other participants in the room with your current location.
1808async fn update_participant_location(
1809    request: proto::UpdateParticipantLocation,
1810    response: Response<proto::UpdateParticipantLocation>,
1811    session: Session,
1812) -> Result<()> {
1813    let room_id = RoomId::from_proto(request.room_id);
1814    let location = request.location.context("invalid location")?;
1815
1816    let db = session.db().await;
1817    let room = db
1818        .update_room_participant_location(room_id, session.connection_id, location)
1819        .await?;
1820
1821    room_updated(&room, &session.peer);
1822    response.send(proto::Ack {})?;
1823    Ok(())
1824}
1825
1826/// Share a project into the room.
1827async fn share_project(
1828    request: proto::ShareProject,
1829    response: Response<proto::ShareProject>,
1830    session: Session,
1831) -> Result<()> {
1832    let (project_id, room) = &*session
1833        .db()
1834        .await
1835        .share_project(
1836            RoomId::from_proto(request.room_id),
1837            session.connection_id,
1838            &request.worktrees,
1839            request.is_ssh_project,
1840        )
1841        .await?;
1842    response.send(proto::ShareProjectResponse {
1843        project_id: project_id.to_proto(),
1844    })?;
1845    room_updated(room, &session.peer);
1846
1847    Ok(())
1848}
1849
1850/// Unshare a project from the room.
1851async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1852    let project_id = ProjectId::from_proto(message.project_id);
1853    unshare_project_internal(project_id, session.connection_id, &session).await
1854}
1855
/// Unshares `project_id` on behalf of `connection_id`.
///
/// Guests returned by the database receive `UnshareProject`, with
/// `connection_id` passed as the broadcast sender. If the project belonged to a
/// room, the room is re-broadcast. If the database reports that the project
/// should be deleted, it is deleted afterwards, once the inner scope holding
/// the unshare result has ended.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        // `delete` is the database's decision on whether to remove the project;
        // `room` is `None` when the project is not part of a room.
        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    // `room_guard` has been dropped by this point; the deletion acquires a fresh
    // database handle.
    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1893
1894/// Join someone elses shared project.
1895async fn join_project(
1896    request: proto::JoinProject,
1897    response: Response<proto::JoinProject>,
1898    session: Session,
1899) -> Result<()> {
1900    let project_id = ProjectId::from_proto(request.project_id);
1901
1902    tracing::info!(%project_id, "join project");
1903
1904    let db = session.db().await;
1905    let (project, replica_id) = &mut *db
1906        .join_project(
1907            project_id,
1908            session.connection_id,
1909            session.user_id(),
1910            request.committer_name.clone(),
1911            request.committer_email.clone(),
1912        )
1913        .await?;
1914    drop(db);
1915    tracing::info!(%project_id, "join remote project");
1916    let collaborators = project
1917        .collaborators
1918        .iter()
1919        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1920        .map(|collaborator| collaborator.to_proto())
1921        .collect::<Vec<_>>();
1922    let project_id = project.id;
1923    let guest_user_id = session.user_id();
1924
1925    let worktrees = project
1926        .worktrees
1927        .iter()
1928        .map(|(id, worktree)| proto::WorktreeMetadata {
1929            id: *id,
1930            root_name: worktree.root_name.clone(),
1931            visible: worktree.visible,
1932            abs_path: worktree.abs_path.clone(),
1933        })
1934        .collect::<Vec<_>>();
1935
1936    let add_project_collaborator = proto::AddProjectCollaborator {
1937        project_id: project_id.to_proto(),
1938        collaborator: Some(proto::Collaborator {
1939            peer_id: Some(session.connection_id.into()),
1940            replica_id: replica_id.0 as u32,
1941            user_id: guest_user_id.to_proto(),
1942            is_host: false,
1943            committer_name: request.committer_name.clone(),
1944            committer_email: request.committer_email.clone(),
1945        }),
1946    };
1947
1948    for collaborator in &collaborators {
1949        session
1950            .peer
1951            .send(
1952                collaborator.peer_id.unwrap().into(),
1953                add_project_collaborator.clone(),
1954            )
1955            .trace_err();
1956    }
1957
1958    // First, we send the metadata associated with each worktree.
1959    response.send(proto::JoinProjectResponse {
1960        project_id: project.id.0 as u64,
1961        worktrees: worktrees.clone(),
1962        replica_id: replica_id.0 as u32,
1963        collaborators: collaborators.clone(),
1964        language_servers: project.language_servers.clone(),
1965        role: project.role.into(),
1966    })?;
1967
1968    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1969        // Stream this worktree's entries.
1970        let message = proto::UpdateWorktree {
1971            project_id: project_id.to_proto(),
1972            worktree_id,
1973            abs_path: worktree.abs_path.clone(),
1974            root_name: worktree.root_name,
1975            updated_entries: worktree.entries,
1976            removed_entries: Default::default(),
1977            scan_id: worktree.scan_id,
1978            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1979            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1980            removed_repositories: Default::default(),
1981        };
1982        for update in proto::split_worktree_update(message) {
1983            session.peer.send(session.connection_id, update.clone())?;
1984        }
1985
1986        // Stream this worktree's diagnostics.
1987        for summary in worktree.diagnostic_summaries {
1988            session.peer.send(
1989                session.connection_id,
1990                proto::UpdateDiagnosticSummary {
1991                    project_id: project_id.to_proto(),
1992                    worktree_id: worktree.id,
1993                    summary: Some(summary),
1994                },
1995            )?;
1996        }
1997
1998        for settings_file in worktree.settings_files {
1999            session.peer.send(
2000                session.connection_id,
2001                proto::UpdateWorktreeSettings {
2002                    project_id: project_id.to_proto(),
2003                    worktree_id: worktree.id,
2004                    path: settings_file.path,
2005                    content: Some(settings_file.content),
2006                    kind: Some(settings_file.kind.to_proto() as i32),
2007                },
2008            )?;
2009        }
2010    }
2011
2012    for repository in mem::take(&mut project.repositories) {
2013        for update in split_repository_update(repository) {
2014            session.peer.send(session.connection_id, update)?;
2015        }
2016    }
2017
2018    for language_server in &project.language_servers {
2019        session.peer.send(
2020            session.connection_id,
2021            proto::UpdateLanguageServer {
2022                project_id: project_id.to_proto(),
2023                server_name: Some(language_server.name.clone()),
2024                language_server_id: language_server.id,
2025                variant: Some(
2026                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2027                        proto::LspDiskBasedDiagnosticsUpdated {},
2028                    ),
2029                ),
2030            },
2031        )?;
2032    }
2033
2034    Ok(())
2035}
2036
2037/// Leave someone elses shared project.
2038async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
2039    let sender_id = session.connection_id;
2040    let project_id = ProjectId::from_proto(request.project_id);
2041    let db = session.db().await;
2042
2043    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2044    tracing::info!(
2045        %project_id,
2046        "leave project"
2047    );
2048
2049    project_left(project, &session);
2050    if let Some(room) = room {
2051        room_updated(room, &session.peer);
2052    }
2053
2054    Ok(())
2055}
2056
2057/// Updates other participants with changes to the project
2058async fn update_project(
2059    request: proto::UpdateProject,
2060    response: Response<proto::UpdateProject>,
2061    session: Session,
2062) -> Result<()> {
2063    let project_id = ProjectId::from_proto(request.project_id);
2064    let (room, guest_connection_ids) = &*session
2065        .db()
2066        .await
2067        .update_project(project_id, session.connection_id, &request.worktrees)
2068        .await?;
2069    broadcast(
2070        Some(session.connection_id),
2071        guest_connection_ids.iter().copied(),
2072        |connection_id| {
2073            session
2074                .peer
2075                .forward_send(session.connection_id, connection_id, request.clone())
2076        },
2077    );
2078    if let Some(room) = room {
2079        room_updated(room, &session.peer);
2080    }
2081    response.send(proto::Ack {})?;
2082
2083    Ok(())
2084}
2085
2086/// Updates other participants with changes to the worktree
2087async fn update_worktree(
2088    request: proto::UpdateWorktree,
2089    response: Response<proto::UpdateWorktree>,
2090    session: Session,
2091) -> Result<()> {
2092    let guest_connection_ids = session
2093        .db()
2094        .await
2095        .update_worktree(&request, session.connection_id)
2096        .await?;
2097
2098    broadcast(
2099        Some(session.connection_id),
2100        guest_connection_ids.iter().copied(),
2101        |connection_id| {
2102            session
2103                .peer
2104                .forward_send(session.connection_id, connection_id, request.clone())
2105        },
2106    );
2107    response.send(proto::Ack {})?;
2108    Ok(())
2109}
2110
2111async fn update_repository(
2112    request: proto::UpdateRepository,
2113    response: Response<proto::UpdateRepository>,
2114    session: Session,
2115) -> Result<()> {
2116    let guest_connection_ids = session
2117        .db()
2118        .await
2119        .update_repository(&request, session.connection_id)
2120        .await?;
2121
2122    broadcast(
2123        Some(session.connection_id),
2124        guest_connection_ids.iter().copied(),
2125        |connection_id| {
2126            session
2127                .peer
2128                .forward_send(session.connection_id, connection_id, request.clone())
2129        },
2130    );
2131    response.send(proto::Ack {})?;
2132    Ok(())
2133}
2134
2135async fn remove_repository(
2136    request: proto::RemoveRepository,
2137    response: Response<proto::RemoveRepository>,
2138    session: Session,
2139) -> Result<()> {
2140    let guest_connection_ids = session
2141        .db()
2142        .await
2143        .remove_repository(&request, session.connection_id)
2144        .await?;
2145
2146    broadcast(
2147        Some(session.connection_id),
2148        guest_connection_ids.iter().copied(),
2149        |connection_id| {
2150            session
2151                .peer
2152                .forward_send(session.connection_id, connection_id, request.clone())
2153        },
2154    );
2155    response.send(proto::Ack {})?;
2156    Ok(())
2157}
2158
2159/// Updates other participants with changes to the diagnostics
2160async fn update_diagnostic_summary(
2161    message: proto::UpdateDiagnosticSummary,
2162    session: Session,
2163) -> Result<()> {
2164    let guest_connection_ids = session
2165        .db()
2166        .await
2167        .update_diagnostic_summary(&message, session.connection_id)
2168        .await?;
2169
2170    broadcast(
2171        Some(session.connection_id),
2172        guest_connection_ids.iter().copied(),
2173        |connection_id| {
2174            session
2175                .peer
2176                .forward_send(session.connection_id, connection_id, message.clone())
2177        },
2178    );
2179
2180    Ok(())
2181}
2182
2183/// Updates other participants with changes to the worktree settings
2184async fn update_worktree_settings(
2185    message: proto::UpdateWorktreeSettings,
2186    session: Session,
2187) -> Result<()> {
2188    let guest_connection_ids = session
2189        .db()
2190        .await
2191        .update_worktree_settings(&message, session.connection_id)
2192        .await?;
2193
2194    broadcast(
2195        Some(session.connection_id),
2196        guest_connection_ids.iter().copied(),
2197        |connection_id| {
2198            session
2199                .peer
2200                .forward_send(session.connection_id, connection_id, message.clone())
2201        },
2202    );
2203
2204    Ok(())
2205}
2206
2207/// Notify other participants that a language server has started.
2208async fn start_language_server(
2209    request: proto::StartLanguageServer,
2210    session: Session,
2211) -> Result<()> {
2212    let guest_connection_ids = session
2213        .db()
2214        .await
2215        .start_language_server(&request, session.connection_id)
2216        .await?;
2217
2218    broadcast(
2219        Some(session.connection_id),
2220        guest_connection_ids.iter().copied(),
2221        |connection_id| {
2222            session
2223                .peer
2224                .forward_send(session.connection_id, connection_id, request.clone())
2225        },
2226    );
2227    Ok(())
2228}
2229
2230/// Notify other participants that a language server has changed.
2231async fn update_language_server(
2232    request: proto::UpdateLanguageServer,
2233    session: Session,
2234) -> Result<()> {
2235    let project_id = ProjectId::from_proto(request.project_id);
2236    let project_connection_ids = session
2237        .db()
2238        .await
2239        .project_connection_ids(project_id, session.connection_id, true)
2240        .await?;
2241    broadcast(
2242        Some(session.connection_id),
2243        project_connection_ids.iter().copied(),
2244        |connection_id| {
2245            session
2246                .peer
2247                .forward_send(session.connection_id, connection_id, request.clone())
2248        },
2249    );
2250    Ok(())
2251}
2252
2253/// forward a project request to the host. These requests should be read only
2254/// as guests are allowed to send them.
2255async fn forward_read_only_project_request<T>(
2256    request: T,
2257    response: Response<T>,
2258    session: Session,
2259) -> Result<()>
2260where
2261    T: EntityMessage + RequestMessage,
2262{
2263    let project_id = ProjectId::from_proto(request.remote_entity_id());
2264    let host_connection_id = session
2265        .db()
2266        .await
2267        .host_for_read_only_project_request(project_id, session.connection_id)
2268        .await?;
2269    let payload = session
2270        .peer
2271        .forward_request(session.connection_id, host_connection_id, request)
2272        .await?;
2273    response.send(payload)?;
2274    Ok(())
2275}
2276
2277async fn forward_find_search_candidates_request(
2278    request: proto::FindSearchCandidates,
2279    response: Response<proto::FindSearchCandidates>,
2280    session: Session,
2281) -> Result<()> {
2282    let project_id = ProjectId::from_proto(request.remote_entity_id());
2283    let host_connection_id = session
2284        .db()
2285        .await
2286        .host_for_read_only_project_request(project_id, session.connection_id)
2287        .await?;
2288    let payload = session
2289        .peer
2290        .forward_request(session.connection_id, host_connection_id, request)
2291        .await?;
2292    response.send(payload)?;
2293    Ok(())
2294}
2295
2296/// forward a project request to the host. These requests are disallowed
2297/// for guests.
2298async fn forward_mutating_project_request<T>(
2299    request: T,
2300    response: Response<T>,
2301    session: Session,
2302) -> Result<()>
2303where
2304    T: EntityMessage + RequestMessage,
2305{
2306    let project_id = ProjectId::from_proto(request.remote_entity_id());
2307
2308    let host_connection_id = session
2309        .db()
2310        .await
2311        .host_for_mutating_project_request(project_id, session.connection_id)
2312        .await?;
2313    let payload = session
2314        .peer
2315        .forward_request(session.connection_id, host_connection_id, request)
2316        .await?;
2317    response.send(payload)?;
2318    Ok(())
2319}
2320
2321/// Notify other participants that a new buffer has been created
2322async fn create_buffer_for_peer(
2323    request: proto::CreateBufferForPeer,
2324    session: Session,
2325) -> Result<()> {
2326    session
2327        .db()
2328        .await
2329        .check_user_is_project_host(
2330            ProjectId::from_proto(request.project_id),
2331            session.connection_id,
2332        )
2333        .await?;
2334    let peer_id = request.peer_id.context("invalid peer id")?;
2335    session
2336        .peer
2337        .forward_send(session.connection_id, peer_id.into(), request)?;
2338    Ok(())
2339}
2340
2341/// Notify other participants that a buffer has been updated. This is
2342/// allowed for guests as long as the update is limited to selections.
2343async fn update_buffer(
2344    request: proto::UpdateBuffer,
2345    response: Response<proto::UpdateBuffer>,
2346    session: Session,
2347) -> Result<()> {
2348    let project_id = ProjectId::from_proto(request.project_id);
2349    let mut capability = Capability::ReadOnly;
2350
2351    for op in request.operations.iter() {
2352        match op.variant {
2353            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2354            Some(_) => capability = Capability::ReadWrite,
2355        }
2356    }
2357
2358    let host = {
2359        let guard = session
2360            .db()
2361            .await
2362            .connections_for_buffer_update(project_id, session.connection_id, capability)
2363            .await?;
2364
2365        let (host, guests) = &*guard;
2366
2367        broadcast(
2368            Some(session.connection_id),
2369            guests.clone(),
2370            |connection_id| {
2371                session
2372                    .peer
2373                    .forward_send(session.connection_id, connection_id, request.clone())
2374            },
2375        );
2376
2377        *host
2378    };
2379
2380    if host != session.connection_id {
2381        session
2382            .peer
2383            .forward_request(session.connection_id, host, request.clone())
2384            .await?;
2385    }
2386
2387    response.send(proto::Ack {})?;
2388    Ok(())
2389}
2390
/// Relays a context operation from the sender to the other project participants.
///
/// One-way message (no response). Fails with "invalid operation" when the
/// message has no operation. The operation determines the capability required
/// of the sender (see the match below). The database then returns the host and
/// guest connections, and the message is forwarded unchanged to all of them,
/// the host included.
async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let operation = message.operation.as_ref().context("invalid operation")?;
    // Required capability:
    // - a buffer operation that only updates selections, or whose inner
    //   operation has no variant: ReadOnly;
    // - a buffer operation missing its inner operation entirely: ReadWrite;
    // - any other context operation: ReadWrite;
    // - a context operation with no variant: ReadOnly.
    let capability = match operation.variant.as_ref() {
        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
            if let Some(buffer_op) = buffer_op.operation.as_ref() {
                match buffer_op.variant {
                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
                        Capability::ReadOnly
                    }
                    _ => Capability::ReadWrite,
                }
            } else {
                Capability::ReadWrite
            }
        }
        Some(_) => Capability::ReadWrite,
        None => Capability::ReadOnly,
    };

    let guard = session
        .db()
        .await
        .connections_for_buffer_update(project_id, session.connection_id, capability)
        .await?;

    let (host, guests) = &*guard;

    // The host is appended to the recipients. The sender is passed as the first
    // argument to `broadcast` (presumably so it is skipped — TODO confirm).
    broadcast(
        Some(session.connection_id),
        guests.iter().chain([host]).copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2432
2433/// Notify other participants that a project has been updated.
2434async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2435    request: T,
2436    session: Session,
2437) -> Result<()> {
2438    let project_id = ProjectId::from_proto(request.remote_entity_id());
2439    let project_connection_ids = session
2440        .db()
2441        .await
2442        .project_connection_ids(project_id, session.connection_id, false)
2443        .await?;
2444
2445    broadcast(
2446        Some(session.connection_id),
2447        project_connection_ids.iter().copied(),
2448        |connection_id| {
2449            session
2450                .peer
2451                .forward_send(session.connection_id, connection_id, request.clone())
2452        },
2453    );
2454    Ok(())
2455}
2456
2457/// Start following another user in a call.
2458async fn follow(
2459    request: proto::Follow,
2460    response: Response<proto::Follow>,
2461    session: Session,
2462) -> Result<()> {
2463    let room_id = RoomId::from_proto(request.room_id);
2464    let project_id = request.project_id.map(ProjectId::from_proto);
2465    let leader_id = request.leader_id.context("invalid leader id")?.into();
2466    let follower_id = session.connection_id;
2467
2468    session
2469        .db()
2470        .await
2471        .check_room_participants(room_id, leader_id, session.connection_id)
2472        .await?;
2473
2474    let response_payload = session
2475        .peer
2476        .forward_request(session.connection_id, leader_id, request)
2477        .await?;
2478    response.send(response_payload)?;
2479
2480    if let Some(project_id) = project_id {
2481        let room = session
2482            .db()
2483            .await
2484            .follow(room_id, project_id, leader_id, follower_id)
2485            .await?;
2486        room_updated(&room, &session.peer);
2487    }
2488
2489    Ok(())
2490}
2491
2492/// Stop following another user in a call.
2493async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2494    let room_id = RoomId::from_proto(request.room_id);
2495    let project_id = request.project_id.map(ProjectId::from_proto);
2496    let leader_id = request.leader_id.context("invalid leader id")?.into();
2497    let follower_id = session.connection_id;
2498
2499    session
2500        .db()
2501        .await
2502        .check_room_participants(room_id, leader_id, session.connection_id)
2503        .await?;
2504
2505    session
2506        .peer
2507        .forward_send(session.connection_id, leader_id, request)?;
2508
2509    if let Some(project_id) = project_id {
2510        let room = session
2511            .db()
2512            .await
2513            .unfollow(room_id, project_id, leader_id, follower_id)
2514            .await?;
2515        room_updated(&room, &session.peer);
2516    }
2517
2518    Ok(())
2519}
2520
2521/// Notify everyone following you of your current location.
2522async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2523    let room_id = RoomId::from_proto(request.room_id);
2524    let database = session.db.lock().await;
2525
2526    let connection_ids = if let Some(project_id) = request.project_id {
2527        let project_id = ProjectId::from_proto(project_id);
2528        database
2529            .project_connection_ids(project_id, session.connection_id, true)
2530            .await?
2531    } else {
2532        database
2533            .room_connection_ids(room_id, session.connection_id)
2534            .await?
2535    };
2536
2537    // For now, don't send view update messages back to that view's current leader.
2538    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2539        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2540        _ => None,
2541    });
2542
2543    for connection_id in connection_ids.iter().cloned() {
2544        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2545            session
2546                .peer
2547                .forward_send(session.connection_id, connection_id, request.clone())?;
2548        }
2549    }
2550    Ok(())
2551}
2552
2553/// Get public data about users.
2554async fn get_users(
2555    request: proto::GetUsers,
2556    response: Response<proto::GetUsers>,
2557    session: Session,
2558) -> Result<()> {
2559    let user_ids = request
2560        .user_ids
2561        .into_iter()
2562        .map(UserId::from_proto)
2563        .collect();
2564    let users = session
2565        .db()
2566        .await
2567        .get_users_by_ids(user_ids)
2568        .await?
2569        .into_iter()
2570        .map(|user| proto::User {
2571            id: user.id.to_proto(),
2572            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2573            github_login: user.github_login,
2574            name: user.name,
2575        })
2576        .collect();
2577    response.send(proto::UsersResponse { users })?;
2578    Ok(())
2579}
2580
2581/// Search for users (to invite) buy Github login
2582async fn fuzzy_search_users(
2583    request: proto::FuzzySearchUsers,
2584    response: Response<proto::FuzzySearchUsers>,
2585    session: Session,
2586) -> Result<()> {
2587    let query = request.query;
2588    let users = match query.len() {
2589        0 => vec![],
2590        1 | 2 => session
2591            .db()
2592            .await
2593            .get_user_by_github_login(&query)
2594            .await?
2595            .into_iter()
2596            .collect(),
2597        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2598    };
2599    let users = users
2600        .into_iter()
2601        .filter(|user| user.id != session.user_id())
2602        .map(|user| proto::User {
2603            id: user.id.to_proto(),
2604            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2605            github_login: user.github_login,
2606            name: user.name,
2607        })
2608        .collect();
2609    response.send(proto::UsersResponse { users })?;
2610    Ok(())
2611}
2612
2613/// Send a contact request to another user.
2614async fn request_contact(
2615    request: proto::RequestContact,
2616    response: Response<proto::RequestContact>,
2617    session: Session,
2618) -> Result<()> {
2619    let requester_id = session.user_id();
2620    let responder_id = UserId::from_proto(request.responder_id);
2621    if requester_id == responder_id {
2622        return Err(anyhow!("cannot add yourself as a contact"))?;
2623    }
2624
2625    let notifications = session
2626        .db()
2627        .await
2628        .send_contact_request(requester_id, responder_id)
2629        .await?;
2630
2631    // Update outgoing contact requests of requester
2632    let mut update = proto::UpdateContacts::default();
2633    update.outgoing_requests.push(responder_id.to_proto());
2634    for connection_id in session
2635        .connection_pool()
2636        .await
2637        .user_connection_ids(requester_id)
2638    {
2639        session.peer.send(connection_id, update.clone())?;
2640    }
2641
2642    // Update incoming contact requests of responder
2643    let mut update = proto::UpdateContacts::default();
2644    update
2645        .incoming_requests
2646        .push(proto::IncomingContactRequest {
2647            requester_id: requester_id.to_proto(),
2648        });
2649    let connection_pool = session.connection_pool().await;
2650    for connection_id in connection_pool.user_connection_ids(responder_id) {
2651        session.peer.send(connection_id, update.clone())?;
2652    }
2653
2654    send_notifications(&connection_pool, &session.peer, notifications);
2655
2656    response.send(proto::Ack {})?;
2657    Ok(())
2658}
2659
2660/// Accept or decline a contact request
2661async fn respond_to_contact_request(
2662    request: proto::RespondToContactRequest,
2663    response: Response<proto::RespondToContactRequest>,
2664    session: Session,
2665) -> Result<()> {
2666    let responder_id = session.user_id();
2667    let requester_id = UserId::from_proto(request.requester_id);
2668    let db = session.db().await;
2669    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2670        db.dismiss_contact_notification(responder_id, requester_id)
2671            .await?;
2672    } else {
2673        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2674
2675        let notifications = db
2676            .respond_to_contact_request(responder_id, requester_id, accept)
2677            .await?;
2678        let requester_busy = db.is_user_busy(requester_id).await?;
2679        let responder_busy = db.is_user_busy(responder_id).await?;
2680
2681        let pool = session.connection_pool().await;
2682        // Update responder with new contact
2683        let mut update = proto::UpdateContacts::default();
2684        if accept {
2685            update
2686                .contacts
2687                .push(contact_for_user(requester_id, requester_busy, &pool));
2688        }
2689        update
2690            .remove_incoming_requests
2691            .push(requester_id.to_proto());
2692        for connection_id in pool.user_connection_ids(responder_id) {
2693            session.peer.send(connection_id, update.clone())?;
2694        }
2695
2696        // Update requester with new contact
2697        let mut update = proto::UpdateContacts::default();
2698        if accept {
2699            update
2700                .contacts
2701                .push(contact_for_user(responder_id, responder_busy, &pool));
2702        }
2703        update
2704            .remove_outgoing_requests
2705            .push(responder_id.to_proto());
2706
2707        for connection_id in pool.user_connection_ids(requester_id) {
2708            session.peer.send(connection_id, update.clone())?;
2709        }
2710
2711        send_notifications(&pool, &session.peer, notifications);
2712    }
2713
2714    response.send(proto::Ack {})?;
2715    Ok(())
2716}
2717
2718/// Remove a contact.
2719async fn remove_contact(
2720    request: proto::RemoveContact,
2721    response: Response<proto::RemoveContact>,
2722    session: Session,
2723) -> Result<()> {
2724    let requester_id = session.user_id();
2725    let responder_id = UserId::from_proto(request.user_id);
2726    let db = session.db().await;
2727    let (contact_accepted, deleted_notification_id) =
2728        db.remove_contact(requester_id, responder_id).await?;
2729
2730    let pool = session.connection_pool().await;
2731    // Update outgoing contact requests of requester
2732    let mut update = proto::UpdateContacts::default();
2733    if contact_accepted {
2734        update.remove_contacts.push(responder_id.to_proto());
2735    } else {
2736        update
2737            .remove_outgoing_requests
2738            .push(responder_id.to_proto());
2739    }
2740    for connection_id in pool.user_connection_ids(requester_id) {
2741        session.peer.send(connection_id, update.clone())?;
2742    }
2743
2744    // Update incoming contact requests of responder
2745    let mut update = proto::UpdateContacts::default();
2746    if contact_accepted {
2747        update.remove_contacts.push(requester_id.to_proto());
2748    } else {
2749        update
2750            .remove_incoming_requests
2751            .push(requester_id.to_proto());
2752    }
2753    for connection_id in pool.user_connection_ids(responder_id) {
2754        session.peer.send(connection_id, update.clone())?;
2755        if let Some(notification_id) = deleted_notification_id {
2756            session.peer.send(
2757                connection_id,
2758                proto::DeleteNotification {
2759                    notification_id: notification_id.to_proto(),
2760                },
2761            )?;
2762        }
2763    }
2764
2765    response.send(proto::Ack {})?;
2766    Ok(())
2767}
2768
2769fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2770    version.0.minor() < 139
2771}
2772
2773async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2774    if is_staff {
2775        return Ok(proto::Plan::ZedPro);
2776    }
2777
2778    let subscription = db.get_active_billing_subscription(user_id).await?;
2779    let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2780
2781    let plan = if let Some(subscription_kind) = subscription_kind {
2782        match subscription_kind {
2783            SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2784            SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2785            SubscriptionKind::ZedFree => proto::Plan::Free,
2786        }
2787    } else {
2788        proto::Plan::Free
2789    };
2790
2791    Ok(plan)
2792}
2793
/// Build the `UpdateUserPlan` message for `user`.
///
/// The message covers the user's plan, trial start, usage-based billing preference,
/// current subscription period, account-age restriction, overdue-invoice flag and
/// model/edit-prediction usage.
///
/// When `llm_db` is `None`, no subscription period or stored usage is looked up. In
/// that case, and whenever no usage row is found, the message reports zero usage
/// together with the plan's limits.
async fn make_update_user_plan_message(
    user: &User,
    is_staff: bool,
    db: &Arc<Database>,
    llm_db: Option<Arc<LlmDatabase>>,
) -> Result<proto::UpdateUserPlan> {
    let feature_flags = db.get_user_flags(user.id).await?;
    let plan = current_plan(db, user.id, is_staff).await?;
    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
    let billing_preferences = db.get_billing_preferences(user.id).await?;

    // NOTE(review): `current_plan` already fetched the active billing subscription;
    // it is queried a second time below.
    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
        let subscription = db.get_active_billing_subscription(user.id).await?;

        let subscription_period =
            crate::db::billing_subscription::Model::current_period(subscription, is_staff);

        // Usage is only looked up for a known period.
        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
            llm_db
                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
                .await?
        } else {
            None
        };

        (subscription_period, usage)
    } else {
        (None, None)
    };

    // Non-Pro accounts younger than MIN_ACCOUNT_AGE_FOR_LLM_USE are flagged, unless the
    // user carries the bypass feature flag.
    let bypass_account_age_check = feature_flags
        .iter()
        .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
    let account_too_young = !matches!(plan, proto::Plan::ZedPro)
        && !bypass_account_age_check
        && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;

    Ok(proto::UpdateUserPlan {
        plan: plan.into(),
        trial_started_at: billing_customer
            .as_ref()
            .and_then(|billing_customer| billing_customer.trial_started_at)
            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
        // Staff always report usage-based billing as enabled.
        is_usage_based_billing_enabled: if is_staff {
            Some(true)
        } else {
            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
        },
        subscription_period: subscription_period.map(|(started_at, ended_at)| {
            proto::SubscriptionPeriod {
                started_at: started_at.timestamp() as u64,
                ended_at: ended_at.timestamp() as u64,
            }
        }),
        account_too_young: Some(account_too_young),
        has_overdue_invoices: billing_customer
            .map(|billing_customer| billing_customer.has_overdue_invoices),
        // Without a recorded usage row, fall back to zero usage with the plan's limits.
        usage: Some(
            usage
                .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
                .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
        ),
    })
}
2858
2859fn model_requests_limit(
2860    plan: zed_llm_client::Plan,
2861    feature_flags: &Vec<String>,
2862) -> zed_llm_client::UsageLimit {
2863    match plan.model_requests_limit() {
2864        zed_llm_client::UsageLimit::Limited(limit) => {
2865            let limit = if plan == zed_llm_client::Plan::ZedProTrial
2866                && feature_flags
2867                    .iter()
2868                    .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2869            {
2870                1_000
2871            } else {
2872                limit
2873            };
2874
2875            zed_llm_client::UsageLimit::Limited(limit)
2876        }
2877        zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
2878    }
2879}
2880
2881fn subscription_usage_to_proto(
2882    plan: proto::Plan,
2883    usage: crate::llm::db::subscription_usage::Model,
2884    feature_flags: &Vec<String>,
2885) -> proto::SubscriptionUsage {
2886    let plan = match plan {
2887        proto::Plan::Free => zed_llm_client::Plan::ZedFree,
2888        proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
2889        proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
2890    };
2891
2892    proto::SubscriptionUsage {
2893        model_requests_usage_amount: usage.model_requests as u32,
2894        model_requests_usage_limit: Some(proto::UsageLimit {
2895            variant: Some(match model_requests_limit(plan, feature_flags) {
2896                zed_llm_client::UsageLimit::Limited(limit) => {
2897                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2898                        limit: limit as u32,
2899                    })
2900                }
2901                zed_llm_client::UsageLimit::Unlimited => {
2902                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2903                }
2904            }),
2905        }),
2906        edit_predictions_usage_amount: usage.edit_predictions as u32,
2907        edit_predictions_usage_limit: Some(proto::UsageLimit {
2908            variant: Some(match plan.edit_predictions_limit() {
2909                zed_llm_client::UsageLimit::Limited(limit) => {
2910                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2911                        limit: limit as u32,
2912                    })
2913                }
2914                zed_llm_client::UsageLimit::Unlimited => {
2915                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2916                }
2917            }),
2918        }),
2919    }
2920}
2921
2922fn make_default_subscription_usage(
2923    plan: proto::Plan,
2924    feature_flags: &Vec<String>,
2925) -> proto::SubscriptionUsage {
2926    let plan = match plan {
2927        proto::Plan::Free => zed_llm_client::Plan::ZedFree,
2928        proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
2929        proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
2930    };
2931
2932    proto::SubscriptionUsage {
2933        model_requests_usage_amount: 0,
2934        model_requests_usage_limit: Some(proto::UsageLimit {
2935            variant: Some(match model_requests_limit(plan, feature_flags) {
2936                zed_llm_client::UsageLimit::Limited(limit) => {
2937                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2938                        limit: limit as u32,
2939                    })
2940                }
2941                zed_llm_client::UsageLimit::Unlimited => {
2942                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2943                }
2944            }),
2945        }),
2946        edit_predictions_usage_amount: 0,
2947        edit_predictions_usage_limit: Some(proto::UsageLimit {
2948            variant: Some(match plan.edit_predictions_limit() {
2949                zed_llm_client::UsageLimit::Limited(limit) => {
2950                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2951                        limit: limit as u32,
2952                    })
2953                }
2954                zed_llm_client::UsageLimit::Unlimited => {
2955                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2956                }
2957            }),
2958        }),
2959    }
2960}
2961
2962async fn update_user_plan(session: &Session) -> Result<()> {
2963    let db = session.db().await;
2964
2965    let update_user_plan = make_update_user_plan_message(
2966        session.principal.user(),
2967        session.is_staff(),
2968        &db.0,
2969        session.app_state.llm_db.clone(),
2970    )
2971    .await?;
2972
2973    session
2974        .peer
2975        .send(session.connection_id, update_user_plan)
2976        .trace_err();
2977
2978    Ok(())
2979}
2980
2981async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2982    subscribe_user_to_channels(session.user_id(), &session).await?;
2983    Ok(())
2984}
2985
2986async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2987    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2988    let mut pool = session.connection_pool().await;
2989    for membership in &channels_for_user.channel_memberships {
2990        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2991    }
2992    session.peer.send(
2993        session.connection_id,
2994        build_update_user_channels(&channels_for_user),
2995    )?;
2996    session.peer.send(
2997        session.connection_id,
2998        build_channels_update(channels_for_user),
2999    )?;
3000    Ok(())
3001}
3002
3003/// Creates a new channel.
3004async fn create_channel(
3005    request: proto::CreateChannel,
3006    response: Response<proto::CreateChannel>,
3007    session: Session,
3008) -> Result<()> {
3009    let db = session.db().await;
3010
3011    let parent_id = request.parent_id.map(ChannelId::from_proto);
3012    let (channel, membership) = db
3013        .create_channel(&request.name, parent_id, session.user_id())
3014        .await?;
3015
3016    let root_id = channel.root_id();
3017    let channel = Channel::from_model(channel);
3018
3019    response.send(proto::CreateChannelResponse {
3020        channel: Some(channel.to_proto()),
3021        parent_id: request.parent_id,
3022    })?;
3023
3024    let mut connection_pool = session.connection_pool().await;
3025    if let Some(membership) = membership {
3026        connection_pool.subscribe_to_channel(
3027            membership.user_id,
3028            membership.channel_id,
3029            membership.role,
3030        );
3031        let update = proto::UpdateUserChannels {
3032            channel_memberships: vec![proto::ChannelMembership {
3033                channel_id: membership.channel_id.to_proto(),
3034                role: membership.role.into(),
3035            }],
3036            ..Default::default()
3037        };
3038        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3039            session.peer.send(connection_id, update.clone())?;
3040        }
3041    }
3042
3043    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3044        if !role.can_see_channel(channel.visibility) {
3045            continue;
3046        }
3047
3048        let update = proto::UpdateChannels {
3049            channels: vec![channel.to_proto()],
3050            ..Default::default()
3051        };
3052        session.peer.send(connection_id, update.clone())?;
3053    }
3054
3055    Ok(())
3056}
3057
3058/// Delete a channel
3059async fn delete_channel(
3060    request: proto::DeleteChannel,
3061    response: Response<proto::DeleteChannel>,
3062    session: Session,
3063) -> Result<()> {
3064    let db = session.db().await;
3065
3066    let channel_id = request.channel_id;
3067    let (root_channel, removed_channels) = db
3068        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3069        .await?;
3070    response.send(proto::Ack {})?;
3071
3072    // Notify members of removed channels
3073    let mut update = proto::UpdateChannels::default();
3074    update
3075        .delete_channels
3076        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3077
3078    let connection_pool = session.connection_pool().await;
3079    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3080        session.peer.send(connection_id, update.clone())?;
3081    }
3082
3083    Ok(())
3084}
3085
3086/// Invite someone to join a channel.
3087async fn invite_channel_member(
3088    request: proto::InviteChannelMember,
3089    response: Response<proto::InviteChannelMember>,
3090    session: Session,
3091) -> Result<()> {
3092    let db = session.db().await;
3093    let channel_id = ChannelId::from_proto(request.channel_id);
3094    let invitee_id = UserId::from_proto(request.user_id);
3095    let InviteMemberResult {
3096        channel,
3097        notifications,
3098    } = db
3099        .invite_channel_member(
3100            channel_id,
3101            invitee_id,
3102            session.user_id(),
3103            request.role().into(),
3104        )
3105        .await?;
3106
3107    let update = proto::UpdateChannels {
3108        channel_invitations: vec![channel.to_proto()],
3109        ..Default::default()
3110    };
3111
3112    let connection_pool = session.connection_pool().await;
3113    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3114        session.peer.send(connection_id, update.clone())?;
3115    }
3116
3117    send_notifications(&connection_pool, &session.peer, notifications);
3118
3119    response.send(proto::Ack {})?;
3120    Ok(())
3121}
3122
3123/// remove someone from a channel
3124async fn remove_channel_member(
3125    request: proto::RemoveChannelMember,
3126    response: Response<proto::RemoveChannelMember>,
3127    session: Session,
3128) -> Result<()> {
3129    let db = session.db().await;
3130    let channel_id = ChannelId::from_proto(request.channel_id);
3131    let member_id = UserId::from_proto(request.user_id);
3132
3133    let RemoveChannelMemberResult {
3134        membership_update,
3135        notification_id,
3136    } = db
3137        .remove_channel_member(channel_id, member_id, session.user_id())
3138        .await?;
3139
3140    let mut connection_pool = session.connection_pool().await;
3141    notify_membership_updated(
3142        &mut connection_pool,
3143        membership_update,
3144        member_id,
3145        &session.peer,
3146    );
3147    for connection_id in connection_pool.user_connection_ids(member_id) {
3148        if let Some(notification_id) = notification_id {
3149            session
3150                .peer
3151                .send(
3152                    connection_id,
3153                    proto::DeleteNotification {
3154                        notification_id: notification_id.to_proto(),
3155                    },
3156                )
3157                .trace_err();
3158        }
3159    }
3160
3161    response.send(proto::Ack {})?;
3162    Ok(())
3163}
3164
3165/// Toggle the channel between public and private.
3166/// Care is taken to maintain the invariant that public channels only descend from public channels,
3167/// (though members-only channels can appear at any point in the hierarchy).
3168async fn set_channel_visibility(
3169    request: proto::SetChannelVisibility,
3170    response: Response<proto::SetChannelVisibility>,
3171    session: Session,
3172) -> Result<()> {
3173    let db = session.db().await;
3174    let channel_id = ChannelId::from_proto(request.channel_id);
3175    let visibility = request.visibility().into();
3176
3177    let channel_model = db
3178        .set_channel_visibility(channel_id, visibility, session.user_id())
3179        .await?;
3180    let root_id = channel_model.root_id();
3181    let channel = Channel::from_model(channel_model);
3182
3183    let mut connection_pool = session.connection_pool().await;
3184    for (user_id, role) in connection_pool
3185        .channel_user_ids(root_id)
3186        .collect::<Vec<_>>()
3187        .into_iter()
3188    {
3189        let update = if role.can_see_channel(channel.visibility) {
3190            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3191            proto::UpdateChannels {
3192                channels: vec![channel.to_proto()],
3193                ..Default::default()
3194            }
3195        } else {
3196            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3197            proto::UpdateChannels {
3198                delete_channels: vec![channel.id.to_proto()],
3199                ..Default::default()
3200            }
3201        };
3202
3203        for connection_id in connection_pool.user_connection_ids(user_id) {
3204            session.peer.send(connection_id, update.clone())?;
3205        }
3206    }
3207
3208    response.send(proto::Ack {})?;
3209    Ok(())
3210}
3211
3212/// Alter the role for a user in the channel.
3213async fn set_channel_member_role(
3214    request: proto::SetChannelMemberRole,
3215    response: Response<proto::SetChannelMemberRole>,
3216    session: Session,
3217) -> Result<()> {
3218    let db = session.db().await;
3219    let channel_id = ChannelId::from_proto(request.channel_id);
3220    let member_id = UserId::from_proto(request.user_id);
3221    let result = db
3222        .set_channel_member_role(
3223            channel_id,
3224            session.user_id(),
3225            member_id,
3226            request.role().into(),
3227        )
3228        .await?;
3229
3230    match result {
3231        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3232            let mut connection_pool = session.connection_pool().await;
3233            notify_membership_updated(
3234                &mut connection_pool,
3235                membership_update,
3236                member_id,
3237                &session.peer,
3238            )
3239        }
3240        db::SetMemberRoleResult::InviteUpdated(channel) => {
3241            let update = proto::UpdateChannels {
3242                channel_invitations: vec![channel.to_proto()],
3243                ..Default::default()
3244            };
3245
3246            for connection_id in session
3247                .connection_pool()
3248                .await
3249                .user_connection_ids(member_id)
3250            {
3251                session.peer.send(connection_id, update.clone())?;
3252            }
3253        }
3254    }
3255
3256    response.send(proto::Ack {})?;
3257    Ok(())
3258}
3259
3260/// Change the name of a channel
3261async fn rename_channel(
3262    request: proto::RenameChannel,
3263    response: Response<proto::RenameChannel>,
3264    session: Session,
3265) -> Result<()> {
3266    let db = session.db().await;
3267    let channel_id = ChannelId::from_proto(request.channel_id);
3268    let channel_model = db
3269        .rename_channel(channel_id, session.user_id(), &request.name)
3270        .await?;
3271    let root_id = channel_model.root_id();
3272    let channel = Channel::from_model(channel_model);
3273
3274    response.send(proto::RenameChannelResponse {
3275        channel: Some(channel.to_proto()),
3276    })?;
3277
3278    let connection_pool = session.connection_pool().await;
3279    let update = proto::UpdateChannels {
3280        channels: vec![channel.to_proto()],
3281        ..Default::default()
3282    };
3283    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3284        if role.can_see_channel(channel.visibility) {
3285            session.peer.send(connection_id, update.clone())?;
3286        }
3287    }
3288
3289    Ok(())
3290}
3291
3292/// Move a channel to a new parent.
3293async fn move_channel(
3294    request: proto::MoveChannel,
3295    response: Response<proto::MoveChannel>,
3296    session: Session,
3297) -> Result<()> {
3298    let channel_id = ChannelId::from_proto(request.channel_id);
3299    let to = ChannelId::from_proto(request.to);
3300
3301    let (root_id, channels) = session
3302        .db()
3303        .await
3304        .move_channel(channel_id, to, session.user_id())
3305        .await?;
3306
3307    let connection_pool = session.connection_pool().await;
3308    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3309        let channels = channels
3310            .iter()
3311            .filter_map(|channel| {
3312                if role.can_see_channel(channel.visibility) {
3313                    Some(channel.to_proto())
3314                } else {
3315                    None
3316                }
3317            })
3318            .collect::<Vec<_>>();
3319        if channels.is_empty() {
3320            continue;
3321        }
3322
3323        let update = proto::UpdateChannels {
3324            channels,
3325            ..Default::default()
3326        };
3327
3328        session.peer.send(connection_id, update.clone())?;
3329    }
3330
3331    response.send(Ack {})?;
3332    Ok(())
3333}
3334
3335async fn reorder_channel(
3336    request: proto::ReorderChannel,
3337    response: Response<proto::ReorderChannel>,
3338    session: Session,
3339) -> Result<()> {
3340    let channel_id = ChannelId::from_proto(request.channel_id);
3341    let direction = request.direction();
3342
3343    let updated_channels = session
3344        .db()
3345        .await
3346        .reorder_channel(channel_id, direction, session.user_id())
3347        .await?;
3348
3349    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3350        let connection_pool = session.connection_pool().await;
3351        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3352            let channels = updated_channels
3353                .iter()
3354                .filter_map(|channel| {
3355                    if role.can_see_channel(channel.visibility) {
3356                        Some(channel.to_proto())
3357                    } else {
3358                        None
3359                    }
3360                })
3361                .collect::<Vec<_>>();
3362
3363            if channels.is_empty() {
3364                continue;
3365            }
3366
3367            let update = proto::UpdateChannels {
3368                channels,
3369                ..Default::default()
3370            };
3371
3372            session.peer.send(connection_id, update.clone())?;
3373        }
3374    }
3375
3376    response.send(Ack {})?;
3377    Ok(())
3378}
3379
3380/// Get the list of channel members
3381async fn get_channel_members(
3382    request: proto::GetChannelMembers,
3383    response: Response<proto::GetChannelMembers>,
3384    session: Session,
3385) -> Result<()> {
3386    let db = session.db().await;
3387    let channel_id = ChannelId::from_proto(request.channel_id);
3388    let limit = if request.limit == 0 {
3389        u16::MAX as u64
3390    } else {
3391        request.limit
3392    };
3393    let (members, users) = db
3394        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3395        .await?;
3396    response.send(proto::GetChannelMembersResponse { members, users })?;
3397    Ok(())
3398}
3399
3400/// Accept or decline a channel invitation.
3401async fn respond_to_channel_invite(
3402    request: proto::RespondToChannelInvite,
3403    response: Response<proto::RespondToChannelInvite>,
3404    session: Session,
3405) -> Result<()> {
3406    let db = session.db().await;
3407    let channel_id = ChannelId::from_proto(request.channel_id);
3408    let RespondToChannelInvite {
3409        membership_update,
3410        notifications,
3411    } = db
3412        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3413        .await?;
3414
3415    let mut connection_pool = session.connection_pool().await;
3416    if let Some(membership_update) = membership_update {
3417        notify_membership_updated(
3418            &mut connection_pool,
3419            membership_update,
3420            session.user_id(),
3421            &session.peer,
3422        );
3423    } else {
3424        let update = proto::UpdateChannels {
3425            remove_channel_invitations: vec![channel_id.to_proto()],
3426            ..Default::default()
3427        };
3428
3429        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3430            session.peer.send(connection_id, update.clone())?;
3431        }
3432    };
3433
3434    send_notifications(&connection_pool, &session.peer, notifications);
3435
3436    response.send(proto::Ack {})?;
3437
3438    Ok(())
3439}
3440
3441/// Join the channels' room
3442async fn join_channel(
3443    request: proto::JoinChannel,
3444    response: Response<proto::JoinChannel>,
3445    session: Session,
3446) -> Result<()> {
3447    let channel_id = ChannelId::from_proto(request.channel_id);
3448    join_channel_internal(channel_id, Box::new(response), session).await
3449}
3450
/// A response handle that can answer a request to join a channel's room.
///
/// Both `JoinChannel` and `JoinRoom` requests reply with a
/// `proto::JoinRoomResponse`, so `join_channel_internal` accepts either handle
/// through this trait.
trait JoinChannelInternalResponse {
    /// Send `result` to the requester, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Each impl forwards to `Response`'s own `send` through an explicit,
// type-qualified path rather than `self.send(..)`, so the call cannot be read
// as (or resolve to) this trait method.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3464
/// Shared implementation behind joining a channel's room.
///
/// `response` can be any handle implementing `JoinChannelInternalResponse`, so
/// the same logic serves more than one request type.
///
/// In order, this function:
/// 1. removes a stale room connection for this user via
///    `leave_room_for_session`, if the database reports one;
/// 2. joins the user to the channel in the database, which yields the joined
///    room, an optional membership update, and the user's channel role;
/// 3. if a LiveKit client is configured, builds connection info. Guests get a
///    `guest_token` with `can_publish: false`; every other role gets a
///    `room_token` with `can_publish: true`;
/// 4. sends the `JoinRoomResponse`, pushes out any membership update, and
///    calls `room_updated`;
/// 5. once the database guard has been dropped, calls `channel_updated` and
///    `update_user_contacts`.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - a database call fails;
/// - the stale-connection cleanup fails;
/// - sending the response fails;
/// - `update_user_contacts` fails;
/// - the joined room carries no channel. This check runs after the response
///   has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The database guard `db` lives only inside this block. It is therefore
    // released before `channel_updated` and `update_user_contacts` run below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while `leave_room_for_session` runs,
            // then re-acquire it. That call receives the whole session and
            // presumably locks the database itself (TODO confirm).
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // This is `None` in two cases: no LiveKit client is configured, or
        // minting the token failed. On failure, `trace_err()` yields `None`
        // and the `?` returns early from the closure, so the join still
        // succeeds, just without media connection info.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        // The connection pool guard is held, together with `db`, until the end
        // of this block.
        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3558
3559/// Start editing the channel notes
3560async fn join_channel_buffer(
3561    request: proto::JoinChannelBuffer,
3562    response: Response<proto::JoinChannelBuffer>,
3563    session: Session,
3564) -> Result<()> {
3565    let db = session.db().await;
3566    let channel_id = ChannelId::from_proto(request.channel_id);
3567
3568    let open_response = db
3569        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3570        .await?;
3571
3572    let collaborators = open_response.collaborators.clone();
3573    response.send(open_response)?;
3574
3575    let update = UpdateChannelBufferCollaborators {
3576        channel_id: channel_id.to_proto(),
3577        collaborators: collaborators.clone(),
3578    };
3579    channel_buffer_updated(
3580        session.connection_id,
3581        collaborators
3582            .iter()
3583            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3584        &update,
3585        &session.peer,
3586    );
3587
3588    Ok(())
3589}
3590
3591/// Edit the channel notes
3592async fn update_channel_buffer(
3593    request: proto::UpdateChannelBuffer,
3594    session: Session,
3595) -> Result<()> {
3596    let db = session.db().await;
3597    let channel_id = ChannelId::from_proto(request.channel_id);
3598
3599    let (collaborators, epoch, version) = db
3600        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3601        .await?;
3602
3603    channel_buffer_updated(
3604        session.connection_id,
3605        collaborators.clone(),
3606        &proto::UpdateChannelBuffer {
3607            channel_id: channel_id.to_proto(),
3608            operations: request.operations,
3609        },
3610        &session.peer,
3611    );
3612
3613    let pool = &*session.connection_pool().await;
3614
3615    let non_collaborators =
3616        pool.channel_connection_ids(channel_id)
3617            .filter_map(|(connection_id, _)| {
3618                if collaborators.contains(&connection_id) {
3619                    None
3620                } else {
3621                    Some(connection_id)
3622                }
3623            });
3624
3625    broadcast(None, non_collaborators, |peer_id| {
3626        session.peer.send(
3627            peer_id,
3628            proto::UpdateChannels {
3629                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3630                    channel_id: channel_id.to_proto(),
3631                    epoch: epoch as u64,
3632                    version: version.clone(),
3633                }],
3634                ..Default::default()
3635            },
3636        )
3637    });
3638
3639    Ok(())
3640}
3641
3642/// Rejoin the channel notes after a connection blip
3643async fn rejoin_channel_buffers(
3644    request: proto::RejoinChannelBuffers,
3645    response: Response<proto::RejoinChannelBuffers>,
3646    session: Session,
3647) -> Result<()> {
3648    let db = session.db().await;
3649    let buffers = db
3650        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3651        .await?;
3652
3653    for rejoined_buffer in &buffers {
3654        let collaborators_to_notify = rejoined_buffer
3655            .buffer
3656            .collaborators
3657            .iter()
3658            .filter_map(|c| Some(c.peer_id?.into()));
3659        channel_buffer_updated(
3660            session.connection_id,
3661            collaborators_to_notify,
3662            &proto::UpdateChannelBufferCollaborators {
3663                channel_id: rejoined_buffer.buffer.channel_id,
3664                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3665            },
3666            &session.peer,
3667        );
3668    }
3669
3670    response.send(proto::RejoinChannelBuffersResponse {
3671        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3672    })?;
3673
3674    Ok(())
3675}
3676
3677/// Stop editing the channel notes
3678async fn leave_channel_buffer(
3679    request: proto::LeaveChannelBuffer,
3680    response: Response<proto::LeaveChannelBuffer>,
3681    session: Session,
3682) -> Result<()> {
3683    let db = session.db().await;
3684    let channel_id = ChannelId::from_proto(request.channel_id);
3685
3686    let left_buffer = db
3687        .leave_channel_buffer(channel_id, session.connection_id)
3688        .await?;
3689
3690    response.send(Ack {})?;
3691
3692    channel_buffer_updated(
3693        session.connection_id,
3694        left_buffer.connections,
3695        &proto::UpdateChannelBufferCollaborators {
3696            channel_id: channel_id.to_proto(),
3697            collaborators: left_buffer.collaborators,
3698        },
3699        &session.peer,
3700    );
3701
3702    Ok(())
3703}
3704
3705fn channel_buffer_updated<T: EnvelopedMessage>(
3706    sender_id: ConnectionId,
3707    collaborators: impl IntoIterator<Item = ConnectionId>,
3708    message: &T,
3709    peer: &Peer,
3710) {
3711    broadcast(Some(sender_id), collaborators, |peer_id| {
3712        peer.send(peer_id, message.clone())
3713    });
3714}
3715
3716fn send_notifications(
3717    connection_pool: &ConnectionPool,
3718    peer: &Peer,
3719    notifications: db::NotificationBatch,
3720) {
3721    for (user_id, notification) in notifications {
3722        for connection_id in connection_pool.user_connection_ids(user_id) {
3723            if let Err(error) = peer.send(
3724                connection_id,
3725                proto::AddNotification {
3726                    notification: Some(notification.clone()),
3727                },
3728            ) {
3729                tracing::error!(
3730                    "failed to send notification to {:?} {}",
3731                    connection_id,
3732                    error
3733                );
3734            }
3735        }
3736    }
3737}
3738
3739/// Send a message to the channel
3740async fn send_channel_message(
3741    request: proto::SendChannelMessage,
3742    response: Response<proto::SendChannelMessage>,
3743    session: Session,
3744) -> Result<()> {
3745    // Validate the message body.
3746    let body = request.body.trim().to_string();
3747    if body.len() > MAX_MESSAGE_LEN {
3748        return Err(anyhow!("message is too long"))?;
3749    }
3750    if body.is_empty() {
3751        return Err(anyhow!("message can't be blank"))?;
3752    }
3753
3754    // TODO: adjust mentions if body is trimmed
3755
3756    let timestamp = OffsetDateTime::now_utc();
3757    let nonce = request.nonce.context("nonce can't be blank")?;
3758
3759    let channel_id = ChannelId::from_proto(request.channel_id);
3760    let CreatedChannelMessage {
3761        message_id,
3762        participant_connection_ids,
3763        notifications,
3764    } = session
3765        .db()
3766        .await
3767        .create_channel_message(
3768            channel_id,
3769            session.user_id(),
3770            &body,
3771            &request.mentions,
3772            timestamp,
3773            nonce.clone().into(),
3774            request.reply_to_message_id.map(MessageId::from_proto),
3775        )
3776        .await?;
3777
3778    let message = proto::ChannelMessage {
3779        sender_id: session.user_id().to_proto(),
3780        id: message_id.to_proto(),
3781        body,
3782        mentions: request.mentions,
3783        timestamp: timestamp.unix_timestamp() as u64,
3784        nonce: Some(nonce),
3785        reply_to_message_id: request.reply_to_message_id,
3786        edited_at: None,
3787    };
3788    broadcast(
3789        Some(session.connection_id),
3790        participant_connection_ids.clone(),
3791        |connection| {
3792            session.peer.send(
3793                connection,
3794                proto::ChannelMessageSent {
3795                    channel_id: channel_id.to_proto(),
3796                    message: Some(message.clone()),
3797                },
3798            )
3799        },
3800    );
3801    response.send(proto::SendChannelMessageResponse {
3802        message: Some(message),
3803    })?;
3804
3805    let pool = &*session.connection_pool().await;
3806    let non_participants =
3807        pool.channel_connection_ids(channel_id)
3808            .filter_map(|(connection_id, _)| {
3809                if participant_connection_ids.contains(&connection_id) {
3810                    None
3811                } else {
3812                    Some(connection_id)
3813                }
3814            });
3815    broadcast(None, non_participants, |peer_id| {
3816        session.peer.send(
3817            peer_id,
3818            proto::UpdateChannels {
3819                latest_channel_message_ids: vec![proto::ChannelMessageId {
3820                    channel_id: channel_id.to_proto(),
3821                    message_id: message_id.to_proto(),
3822                }],
3823                ..Default::default()
3824            },
3825        )
3826    });
3827    send_notifications(pool, &session.peer, notifications);
3828
3829    Ok(())
3830}
3831
3832/// Delete a channel message
3833async fn remove_channel_message(
3834    request: proto::RemoveChannelMessage,
3835    response: Response<proto::RemoveChannelMessage>,
3836    session: Session,
3837) -> Result<()> {
3838    let channel_id = ChannelId::from_proto(request.channel_id);
3839    let message_id = MessageId::from_proto(request.message_id);
3840    let (connection_ids, existing_notification_ids) = session
3841        .db()
3842        .await
3843        .remove_channel_message(channel_id, message_id, session.user_id())
3844        .await?;
3845
3846    broadcast(
3847        Some(session.connection_id),
3848        connection_ids,
3849        move |connection| {
3850            session.peer.send(connection, request.clone())?;
3851
3852            for notification_id in &existing_notification_ids {
3853                session.peer.send(
3854                    connection,
3855                    proto::DeleteNotification {
3856                        notification_id: (*notification_id).to_proto(),
3857                    },
3858                )?;
3859            }
3860
3861            Ok(())
3862        },
3863    );
3864    response.send(proto::Ack {})?;
3865    Ok(())
3866}
3867
3868async fn update_channel_message(
3869    request: proto::UpdateChannelMessage,
3870    response: Response<proto::UpdateChannelMessage>,
3871    session: Session,
3872) -> Result<()> {
3873    let channel_id = ChannelId::from_proto(request.channel_id);
3874    let message_id = MessageId::from_proto(request.message_id);
3875    let updated_at = OffsetDateTime::now_utc();
3876    let UpdatedChannelMessage {
3877        message_id,
3878        participant_connection_ids,
3879        notifications,
3880        reply_to_message_id,
3881        timestamp,
3882        deleted_mention_notification_ids,
3883        updated_mention_notifications,
3884    } = session
3885        .db()
3886        .await
3887        .update_channel_message(
3888            channel_id,
3889            message_id,
3890            session.user_id(),
3891            request.body.as_str(),
3892            &request.mentions,
3893            updated_at,
3894        )
3895        .await?;
3896
3897    let nonce = request.nonce.clone().context("nonce can't be blank")?;
3898
3899    let message = proto::ChannelMessage {
3900        sender_id: session.user_id().to_proto(),
3901        id: message_id.to_proto(),
3902        body: request.body.clone(),
3903        mentions: request.mentions.clone(),
3904        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3905        nonce: Some(nonce),
3906        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3907        edited_at: Some(updated_at.unix_timestamp() as u64),
3908    };
3909
3910    response.send(proto::Ack {})?;
3911
3912    let pool = &*session.connection_pool().await;
3913    broadcast(
3914        Some(session.connection_id),
3915        participant_connection_ids,
3916        |connection| {
3917            session.peer.send(
3918                connection,
3919                proto::ChannelMessageUpdate {
3920                    channel_id: channel_id.to_proto(),
3921                    message: Some(message.clone()),
3922                },
3923            )?;
3924
3925            for notification_id in &deleted_mention_notification_ids {
3926                session.peer.send(
3927                    connection,
3928                    proto::DeleteNotification {
3929                        notification_id: (*notification_id).to_proto(),
3930                    },
3931                )?;
3932            }
3933
3934            for notification in &updated_mention_notifications {
3935                session.peer.send(
3936                    connection,
3937                    proto::UpdateNotification {
3938                        notification: Some(notification.clone()),
3939                    },
3940                )?;
3941            }
3942
3943            Ok(())
3944        },
3945    );
3946
3947    send_notifications(pool, &session.peer, notifications);
3948
3949    Ok(())
3950}
3951
3952/// Mark a channel message as read
3953async fn acknowledge_channel_message(
3954    request: proto::AckChannelMessage,
3955    session: Session,
3956) -> Result<()> {
3957    let channel_id = ChannelId::from_proto(request.channel_id);
3958    let message_id = MessageId::from_proto(request.message_id);
3959    let notifications = session
3960        .db()
3961        .await
3962        .observe_channel_message(channel_id, session.user_id(), message_id)
3963        .await?;
3964    send_notifications(
3965        &*session.connection_pool().await,
3966        &session.peer,
3967        notifications,
3968    );
3969    Ok(())
3970}
3971
3972/// Mark a buffer version as synced
3973async fn acknowledge_buffer_version(
3974    request: proto::AckBufferOperation,
3975    session: Session,
3976) -> Result<()> {
3977    let buffer_id = BufferId::from_proto(request.buffer_id);
3978    session
3979        .db()
3980        .await
3981        .observe_buffer_version(
3982            buffer_id,
3983            session.user_id(),
3984            request.epoch as i32,
3985            &request.version,
3986        )
3987        .await?;
3988    Ok(())
3989}
3990
3991/// Get a Supermaven API key for the user
3992async fn get_supermaven_api_key(
3993    _request: proto::GetSupermavenApiKey,
3994    response: Response<proto::GetSupermavenApiKey>,
3995    session: Session,
3996) -> Result<()> {
3997    let user_id: String = session.user_id().to_string();
3998    if !session.is_staff() {
3999        return Err(anyhow!("supermaven not enabled for this account"))?;
4000    }
4001
4002    let email = session.email().context("user must have an email")?;
4003
4004    let supermaven_admin_api = session
4005        .supermaven_client
4006        .as_ref()
4007        .context("supermaven not configured")?;
4008
4009    let result = supermaven_admin_api
4010        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4011        .await?;
4012
4013    response.send(proto::GetSupermavenApiKeyResponse {
4014        api_key: result.api_key,
4015    })?;
4016
4017    Ok(())
4018}
4019
4020/// Start receiving chat updates for a channel
4021async fn join_channel_chat(
4022    request: proto::JoinChannelChat,
4023    response: Response<proto::JoinChannelChat>,
4024    session: Session,
4025) -> Result<()> {
4026    let channel_id = ChannelId::from_proto(request.channel_id);
4027
4028    let db = session.db().await;
4029    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4030        .await?;
4031    let messages = db
4032        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4033        .await?;
4034    response.send(proto::JoinChannelChatResponse {
4035        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4036        messages,
4037    })?;
4038    Ok(())
4039}
4040
4041/// Stop receiving chat updates for a channel
4042async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
4043    let channel_id = ChannelId::from_proto(request.channel_id);
4044    session
4045        .db()
4046        .await
4047        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4048        .await?;
4049    Ok(())
4050}
4051
4052/// Retrieve the chat history for a channel
4053async fn get_channel_messages(
4054    request: proto::GetChannelMessages,
4055    response: Response<proto::GetChannelMessages>,
4056    session: Session,
4057) -> Result<()> {
4058    let channel_id = ChannelId::from_proto(request.channel_id);
4059    let messages = session
4060        .db()
4061        .await
4062        .get_channel_messages(
4063            channel_id,
4064            session.user_id(),
4065            MESSAGE_COUNT_PER_PAGE,
4066            Some(MessageId::from_proto(request.before_message_id)),
4067        )
4068        .await?;
4069    response.send(proto::GetChannelMessagesResponse {
4070        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4071        messages,
4072    })?;
4073    Ok(())
4074}
4075
4076/// Retrieve specific chat messages
4077async fn get_channel_messages_by_id(
4078    request: proto::GetChannelMessagesById,
4079    response: Response<proto::GetChannelMessagesById>,
4080    session: Session,
4081) -> Result<()> {
4082    let message_ids = request
4083        .message_ids
4084        .iter()
4085        .map(|id| MessageId::from_proto(*id))
4086        .collect::<Vec<_>>();
4087    let messages = session
4088        .db()
4089        .await
4090        .get_channel_messages_by_id(session.user_id(), &message_ids)
4091        .await?;
4092    response.send(proto::GetChannelMessagesResponse {
4093        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4094        messages,
4095    })?;
4096    Ok(())
4097}
4098
4099/// Retrieve the current users notifications
4100async fn get_notifications(
4101    request: proto::GetNotifications,
4102    response: Response<proto::GetNotifications>,
4103    session: Session,
4104) -> Result<()> {
4105    let notifications = session
4106        .db()
4107        .await
4108        .get_notifications(
4109            session.user_id(),
4110            NOTIFICATION_COUNT_PER_PAGE,
4111            request.before_id.map(db::NotificationId::from_proto),
4112        )
4113        .await?;
4114    response.send(proto::GetNotificationsResponse {
4115        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4116        notifications,
4117    })?;
4118    Ok(())
4119}
4120
4121/// Mark notifications as read
4122async fn mark_notification_as_read(
4123    request: proto::MarkNotificationRead,
4124    response: Response<proto::MarkNotificationRead>,
4125    session: Session,
4126) -> Result<()> {
4127    let database = &session.db().await;
4128    let notifications = database
4129        .mark_notification_as_read_by_id(
4130            session.user_id(),
4131            NotificationId::from_proto(request.notification_id),
4132        )
4133        .await?;
4134    send_notifications(
4135        &*session.connection_pool().await,
4136        &session.peer,
4137        notifications,
4138    );
4139    response.send(proto::Ack {})?;
4140    Ok(())
4141}
4142
4143/// Get the current users information
4144async fn get_private_user_info(
4145    _request: proto::GetPrivateUserInfo,
4146    response: Response<proto::GetPrivateUserInfo>,
4147    session: Session,
4148) -> Result<()> {
4149    let db = session.db().await;
4150
4151    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4152    let user = db
4153        .get_user_by_id(session.user_id())
4154        .await?
4155        .context("user not found")?;
4156    let flags = db.get_user_flags(session.user_id()).await?;
4157
4158    response.send(proto::GetPrivateUserInfoResponse {
4159        metrics_id,
4160        staff: user.admin,
4161        flags,
4162        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4163    })?;
4164    Ok(())
4165}
4166
4167/// Accept the terms of service (tos) on behalf of the current user
4168async fn accept_terms_of_service(
4169    _request: proto::AcceptTermsOfService,
4170    response: Response<proto::AcceptTermsOfService>,
4171    session: Session,
4172) -> Result<()> {
4173    let db = session.db().await;
4174
4175    let accepted_tos_at = Utc::now();
4176    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4177        .await?;
4178
4179    response.send(proto::AcceptTermsOfServiceResponse {
4180        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4181    })?;
4182
4183    // When the user accepts the terms of service, we want to refresh their LLM
4184    // token to grant access.
4185    session
4186        .peer
4187        .send(session.connection_id, proto::RefreshLlmToken {})?;
4188
4189    Ok(())
4190}
4191
4192async fn get_llm_api_token(
4193    _request: proto::GetLlmToken,
4194    response: Response<proto::GetLlmToken>,
4195    session: Session,
4196) -> Result<()> {
4197    let db = session.db().await;
4198
4199    let flags = db.get_user_flags(session.user_id()).await?;
4200
4201    let user_id = session.user_id();
4202    let user = db
4203        .get_user_by_id(user_id)
4204        .await?
4205        .with_context(|| format!("user {user_id} not found"))?;
4206
4207    if user.accepted_tos_at.is_none() {
4208        Err(anyhow!("terms of service not accepted"))?
4209    }
4210
4211    let stripe_client = session
4212        .app_state
4213        .stripe_client
4214        .as_ref()
4215        .context("failed to retrieve Stripe client")?;
4216
4217    let stripe_billing = session
4218        .app_state
4219        .stripe_billing
4220        .as_ref()
4221        .context("failed to retrieve Stripe billing object")?;
4222
4223    let billing_customer = if let Some(billing_customer) =
4224        db.get_billing_customer_by_user_id(user.id).await?
4225    {
4226        billing_customer
4227    } else {
4228        let customer_id = stripe_billing
4229            .find_or_create_customer_by_email(user.email_address.as_deref())
4230            .await?;
4231
4232        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
4233            .await?
4234            .context("billing customer not found")?
4235    };
4236
4237    let billing_subscription =
4238        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
4239            billing_subscription
4240        } else {
4241            let stripe_customer_id =
4242                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
4243
4244            let stripe_subscription = stripe_billing
4245                .subscribe_to_zed_free(stripe_customer_id)
4246                .await?;
4247
4248            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
4249                billing_customer_id: billing_customer.id,
4250                kind: Some(SubscriptionKind::ZedFree),
4251                stripe_subscription_id: stripe_subscription.id.to_string(),
4252                stripe_subscription_status: stripe_subscription.status.into(),
4253                stripe_cancellation_reason: None,
4254                stripe_current_period_start: Some(stripe_subscription.current_period_start),
4255                stripe_current_period_end: Some(stripe_subscription.current_period_end),
4256            })
4257            .await?
4258        };
4259
4260    let billing_preferences = db.get_billing_preferences(user.id).await?;
4261
4262    let token = LlmTokenClaims::create(
4263        &user,
4264        session.is_staff(),
4265        billing_customer,
4266        billing_preferences,
4267        &flags,
4268        billing_subscription,
4269        session.system_id.clone(),
4270        &session.app_state.config,
4271    )?;
4272    response.send(proto::GetLlmTokenResponse { token })?;
4273    Ok(())
4274}
4275
4276fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4277    let message = match message {
4278        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4279        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4280        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4281        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4282        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4283            code: frame.code.into(),
4284            reason: frame.reason.as_str().to_owned().into(),
4285        })),
4286        // We should never receive a frame while reading the message, according
4287        // to the `tungstenite` maintainers:
4288        //
4289        // > It cannot occur when you read messages from the WebSocket, but it
4290        // > can be used when you want to send the raw frames (e.g. you want to
4291        // > send the frames to the WebSocket without composing the full message first).
4292        // >
4293        // > — https://github.com/snapview/tungstenite-rs/issues/268
4294        TungsteniteMessage::Frame(_) => {
4295            bail!("received an unexpected frame while reading the message")
4296        }
4297    };
4298
4299    Ok(message)
4300}
4301
4302fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4303    match message {
4304        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4305        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4306        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4307        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4308        AxumMessage::Close(frame) => {
4309            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4310                code: frame.code.into(),
4311                reason: frame.reason.as_ref().into(),
4312            }))
4313        }
4314    }
4315}
4316
4317fn notify_membership_updated(
4318    connection_pool: &mut ConnectionPool,
4319    result: MembershipUpdated,
4320    user_id: UserId,
4321    peer: &Peer,
4322) {
4323    for membership in &result.new_channels.channel_memberships {
4324        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4325    }
4326    for channel_id in &result.removed_channels {
4327        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4328    }
4329
4330    let user_channels_update = proto::UpdateUserChannels {
4331        channel_memberships: result
4332            .new_channels
4333            .channel_memberships
4334            .iter()
4335            .map(|cm| proto::ChannelMembership {
4336                channel_id: cm.channel_id.to_proto(),
4337                role: cm.role.into(),
4338            })
4339            .collect(),
4340        ..Default::default()
4341    };
4342
4343    let mut update = build_channels_update(result.new_channels);
4344    update.delete_channels = result
4345        .removed_channels
4346        .into_iter()
4347        .map(|id| id.to_proto())
4348        .collect();
4349    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4350
4351    for connection_id in connection_pool.user_connection_ids(user_id) {
4352        peer.send(connection_id, user_channels_update.clone())
4353            .trace_err();
4354        peer.send(connection_id, update.clone()).trace_err();
4355    }
4356}
4357
4358fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4359    proto::UpdateUserChannels {
4360        channel_memberships: channels
4361            .channel_memberships
4362            .iter()
4363            .map(|m| proto::ChannelMembership {
4364                channel_id: m.channel_id.to_proto(),
4365                role: m.role.into(),
4366            })
4367            .collect(),
4368        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4369        observed_channel_message_id: channels.observed_channel_messages.clone(),
4370    }
4371}
4372
4373fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4374    let mut update = proto::UpdateChannels::default();
4375
4376    for channel in channels.channels {
4377        update.channels.push(channel.to_proto());
4378    }
4379
4380    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4381    update.latest_channel_message_ids = channels.latest_channel_messages;
4382
4383    for (channel_id, participants) in channels.channel_participants {
4384        update
4385            .channel_participants
4386            .push(proto::ChannelParticipants {
4387                channel_id: channel_id.to_proto(),
4388                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4389            });
4390    }
4391
4392    for channel in channels.invited_channels {
4393        update.channel_invitations.push(channel.to_proto());
4394    }
4395
4396    update
4397}
4398
4399fn build_initial_contacts_update(
4400    contacts: Vec<db::Contact>,
4401    pool: &ConnectionPool,
4402) -> proto::UpdateContacts {
4403    let mut update = proto::UpdateContacts::default();
4404
4405    for contact in contacts {
4406        match contact {
4407            db::Contact::Accepted { user_id, busy } => {
4408                update.contacts.push(contact_for_user(user_id, busy, pool));
4409            }
4410            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4411            db::Contact::Incoming { user_id } => {
4412                update
4413                    .incoming_requests
4414                    .push(proto::IncomingContactRequest {
4415                        requester_id: user_id.to_proto(),
4416                    })
4417            }
4418        }
4419    }
4420
4421    update
4422}
4423
4424fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4425    proto::Contact {
4426        user_id: user_id.to_proto(),
4427        online: pool.is_user_online(user_id),
4428        busy,
4429    }
4430}
4431
4432fn room_updated(room: &proto::Room, peer: &Peer) {
4433    broadcast(
4434        None,
4435        room.participants
4436            .iter()
4437            .filter_map(|participant| Some(participant.peer_id?.into())),
4438        |peer_id| {
4439            peer.send(
4440                peer_id,
4441                proto::RoomUpdated {
4442                    room: Some(room.clone()),
4443                },
4444            )
4445        },
4446    );
4447}
4448
4449fn channel_updated(
4450    channel: &db::channel::Model,
4451    room: &proto::Room,
4452    peer: &Peer,
4453    pool: &ConnectionPool,
4454) {
4455    let participants = room
4456        .participants
4457        .iter()
4458        .map(|p| p.user_id)
4459        .collect::<Vec<_>>();
4460
4461    broadcast(
4462        None,
4463        pool.channel_connection_ids(channel.root_id())
4464            .filter_map(|(channel_id, role)| {
4465                role.can_see_channel(channel.visibility)
4466                    .then_some(channel_id)
4467            }),
4468        |peer_id| {
4469            peer.send(
4470                peer_id,
4471                proto::UpdateChannels {
4472                    channel_participants: vec![proto::ChannelParticipants {
4473                        channel_id: channel.id.to_proto(),
4474                        participant_user_ids: participants.clone(),
4475                    }],
4476                    ..Default::default()
4477                },
4478            )
4479        },
4480    );
4481}
4482
4483async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4484    let db = session.db().await;
4485
4486    let contacts = db.get_contacts(user_id).await?;
4487    let busy = db.is_user_busy(user_id).await?;
4488
4489    let pool = session.connection_pool().await;
4490    let updated_contact = contact_for_user(user_id, busy, &pool);
4491    for contact in contacts {
4492        if let db::Contact::Accepted {
4493            user_id: contact_user_id,
4494            ..
4495        } = contact
4496        {
4497            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4498                session
4499                    .peer
4500                    .send(
4501                        contact_conn_id,
4502                        proto::UpdateContacts {
4503                            contacts: vec![updated_contact.clone()],
4504                            remove_contacts: Default::default(),
4505                            incoming_requests: Default::default(),
4506                            remove_incoming_requests: Default::default(),
4507                            outgoing_requests: Default::default(),
4508                            remove_outgoing_requests: Default::default(),
4509                        },
4510                    )
4511                    .trace_err();
4512            }
4513        }
4514    }
4515    Ok(())
4516}
4517
/// Remove `connection_id` from the room it is in and notify everyone affected.
///
/// If the database reports that the connection was not in a room, this returns
/// `Ok(())` immediately. Otherwise it:
/// - notifies collaborators of every project the user left (via `project_left`);
/// - broadcasts the updated room to its participants, and the new participant list
///   to the channel if the room belongs to one;
/// - sends `CallCanceled` to every connection of each user whose call was canceled;
/// - refreshes contact status for this session's user and for those users;
/// - if a LiveKit client is configured, removes this user from the LiveKit room and
///   deletes the LiveKit room when `left_room.deleted` is set.
///
/// # Errors
///
/// Returns an error if `leave_room` fails or if any `update_user_contacts` call
/// fails. NOTE(review): an `update_user_contacts` error returns early and skips the
/// LiveKit cleanup at the end of this function. Confirm that is intended.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred initialisation: each of these is assigned exactly once inside the
    // `if let` below (the `else` branch returns), so the values extracted from
    // `left_room` remain usable after that block ends.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // NOTE(review): whatever `session.db().await` returns (presumably a lock guard;
    // confirm) is a temporary of this `if let` scrutinee. It therefore stays alive
    // while the body below runs.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before the room itself is taken. This leaves the default value in
        // `left_room.room.livekit_room`, so the `room` broadcast below does not
        // carry it.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool handle is dropped before the loop below.
    // Presumably this avoids holding it while `update_user_contacts` acquires it
    // again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged via `trace_err`.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4591
4592async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4593    let left_channel_buffers = session
4594        .db()
4595        .await
4596        .leave_channel_buffers(session.connection_id)
4597        .await?;
4598
4599    for left_buffer in left_channel_buffers {
4600        channel_buffer_updated(
4601            session.connection_id,
4602            left_buffer.connections,
4603            &proto::UpdateChannelBufferCollaborators {
4604                channel_id: left_buffer.channel_id.to_proto(),
4605                collaborators: left_buffer.collaborators,
4606            },
4607            &session.peer,
4608        );
4609    }
4610
4611    Ok(())
4612}
4613
4614fn project_left(project: &db::LeftProject, session: &Session) {
4615    for connection_id in &project.connection_ids {
4616        if project.should_unshare {
4617            session
4618                .peer
4619                .send(
4620                    *connection_id,
4621                    proto::UnshareProject {
4622                        project_id: project.id.to_proto(),
4623                    },
4624                )
4625                .trace_err();
4626        } else {
4627            session
4628                .peer
4629                .send(
4630                    *connection_id,
4631                    proto::RemoveProjectCollaborator {
4632                        project_id: project.id.to_proto(),
4633                        peer_id: Some(session.connection_id.into()),
4634                    },
4635                )
4636                .trace_err();
4637        }
4638    }
4639}
4640
4641pub trait ResultExt {
4642    type Ok;
4643
4644    fn trace_err(self) -> Option<Self::Ok>;
4645}
4646
4647impl<T, E> ResultExt for Result<T, E>
4648where
4649    E: std::fmt::Debug,
4650{
4651    type Ok = T;
4652
4653    #[track_caller]
4654    fn trace_err(self) -> Option<T> {
4655        match self {
4656            Ok(value) => Some(value),
4657            Err(error) => {
4658                tracing::error!("{:?}", error);
4659                None
4660            }
4661        }
4662    }
4663}