rpc.rs

   1mod connection_pool;
   2
   3use crate::api::billing::find_or_create_billing_customer;
   4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   5use crate::db::billing_subscription::SubscriptionKind;
   6use crate::llm::db::LlmDatabase;
   7use crate::llm::{
   8    AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
   9    MIN_ACCOUNT_AGE_FOR_LLM_USE,
  10};
  11use crate::stripe_client::StripeCustomerId;
  12use crate::{
  13    AppState, Error, Result, auth,
  14    db::{
  15        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
  16        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  17        NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
  18        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  19    },
  20    executor::Executor,
  21};
  22use anyhow::{Context as _, anyhow, bail};
  23use async_tungstenite::tungstenite::{
  24    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  25};
  26use axum::{
  27    Extension, Router, TypedHeader,
  28    body::Body,
  29    extract::{
  30        ConnectInfo, WebSocketUpgrade,
  31        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  32    },
  33    headers::{Header, HeaderName},
  34    http::StatusCode,
  35    middleware,
  36    response::IntoResponse,
  37    routing::get,
  38};
  39use chrono::Utc;
  40use collections::{HashMap, HashSet};
  41pub use connection_pool::{ConnectionPool, ZedVersion};
  42use core::fmt::{self, Debug, Formatter};
  43use reqwest_client::ReqwestClient;
  44use rpc::proto::split_repository_update;
  45use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  46
  47use futures::{
  48    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  49    stream::FuturesUnordered,
  50};
  51use prometheus::{IntGauge, register_int_gauge};
  52use rpc::{
  53    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  54    proto::{
  55        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  56        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  57    },
  58};
  59use semantic_version::SemanticVersion;
  60use serde::{Serialize, Serializer};
  61use std::{
  62    any::TypeId,
  63    future::Future,
  64    marker::PhantomData,
  65    mem,
  66    net::SocketAddr,
  67    ops::{Deref, DerefMut},
  68    rc::Rc,
  69    sync::{
  70        Arc, OnceLock,
  71        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  72    },
  73    time::{Duration, Instant},
  74};
  75use time::OffsetDateTime;
  76use tokio::sync::{Semaphore, watch};
  77use tower::ServiceBuilder;
  78use tracing::{
  79    Instrument,
  80    field::{self},
  81    info_span, instrument,
  82};
  83
/// Reconnection window of 30 seconds.
// NOTE(review): not referenced in this part of the file — presumably the time a
// disconnected client has to reconnect before its state is released; confirm at the use site.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
// clean up old resources.
/// Delay that `Server::start` waits before it begins cleaning up stale resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three paging/length limits below are not used in this part of the file;
// their names suggest channel-message paging, message body length and notification paging —
// confirm at the use sites.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Upper bound on the number of live `ConnectionGuard`s; see `ConnectionGuard::try_acquire`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Number of `ConnectionGuard`s currently alive. Incremented by
/// `ConnectionGuard::try_acquire` and decremented when a guard is dropped.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
  95
/// Type-erased message handler stored in `Server::handlers`.
///
/// It receives the incoming envelope as a boxed `AnyTypedEnvelope` together with the
/// `Session` it arrived on, and returns a boxed `'static` future that resolves to `()`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  98
  99pub struct ConnectionGuard;
 100
 101impl ConnectionGuard {
 102    pub fn try_acquire() -> Result<Self, ()> {
 103        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
 104        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 105            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 106            tracing::error!(
 107                "too many concurrent connections: {}",
 108                current_connections + 1
 109            );
 110            return Err(());
 111        }
 112        Ok(ConnectionGuard)
 113    }
 114}
 115
 116impl Drop for ConnectionGuard {
 117    fn drop(&mut self) {
 118        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 119    }
 120}
 121
 122struct Response<R> {
 123    peer: Arc<Peer>,
 124    receipt: Receipt<R>,
 125    responded: Arc<AtomicBool>,
 126}
 127
 128impl<R: RequestMessage> Response<R> {
 129    fn send(self, payload: R::Response) -> Result<()> {
 130        self.responded.store(true, SeqCst);
 131        self.peer.respond(self.receipt, payload)?;
 132        Ok(())
 133    }
 134}
 135
 136#[derive(Clone, Debug)]
 137pub enum Principal {
 138    User(User),
 139    Impersonated { user: User, admin: User },
 140}
 141
 142impl Principal {
 143    fn user(&self) -> &User {
 144        match self {
 145            Principal::User(user) => user,
 146            Principal::Impersonated { user, .. } => user,
 147        }
 148    }
 149
 150    fn update_span(&self, span: &tracing::Span) {
 151        match &self {
 152            Principal::User(user) => {
 153                span.record("user_id", user.id.0);
 154                span.record("login", &user.github_login);
 155            }
 156            Principal::Impersonated { user, admin } => {
 157                span.record("user_id", user.id.0);
 158                span.record("login", &user.github_login);
 159                span.record("impersonator", &admin.github_login);
 160            }
 161        }
 162    }
 163}
 164
/// Per-connection context handed to every message handler.
///
/// The struct is `Clone`; the database handle, peer, connection pool and app
/// state are held behind `Arc`, so clones share them.
#[derive(Clone)]
struct Session {
    /// Who this connection is acting as.
    principal: Principal,
    /// Identifier of the connection this session belongs to.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; access it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Connection pool behind a synchronous `parking_lot` mutex; access it through
    /// `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Supermaven admin API client, if one is available.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    // NOTE(review): presumably the value of the client's system-id header — confirm at
    // the construction site.
    system_id: Option<String>,
    _executor: Executor,
}
 180
 181impl Session {
 182    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 183        #[cfg(test)]
 184        tokio::task::yield_now().await;
 185        let guard = self.db.lock().await;
 186        #[cfg(test)]
 187        tokio::task::yield_now().await;
 188        guard
 189    }
 190
 191    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 192        #[cfg(test)]
 193        tokio::task::yield_now().await;
 194        let guard = self.connection_pool.lock();
 195        ConnectionPoolGuard {
 196            guard,
 197            _not_send: PhantomData,
 198        }
 199    }
 200
 201    fn is_staff(&self) -> bool {
 202        match &self.principal {
 203            Principal::User(user) => user.admin,
 204            Principal::Impersonated { .. } => true,
 205        }
 206    }
 207
 208    fn user_id(&self) -> UserId {
 209        match &self.principal {
 210            Principal::User(user) => user.id,
 211            Principal::Impersonated { user, .. } => user.id,
 212        }
 213    }
 214
 215    pub fn email(&self) -> Option<String> {
 216        match &self.principal {
 217            Principal::User(user) => user.email_address.clone(),
 218            Principal::Impersonated { user, .. } => user.email_address.clone(),
 219        }
 220    }
 221}
 222
 223impl Debug for Session {
 224    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 225        let mut result = f.debug_struct("Session");
 226        match &self.principal {
 227            Principal::User(user) => {
 228                result.field("user", &user.github_login);
 229            }
 230            Principal::Impersonated { user, admin } => {
 231                result.field("user", &user.github_login);
 232                result.field("impersonator", &admin.github_login);
 233            }
 234        }
 235        result.field("connection_id", &self.connection_id).finish()
 236    }
 237}
 238
 239struct DbHandle(Arc<Database>);
 240
 241impl Deref for DbHandle {
 242    type Target = Database;
 243
 244    fn deref(&self) -> &Self::Target {
 245        self.0.as_ref()
 246    }
 247}
 248
/// The RPC server: owns the peer, the connection pool and the table of message handlers.
pub struct Server {
    /// Current server id. Kept behind a mutex because `reset` (test builds only) replaces it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept; filled by `add_handler`.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Watch channel carrying the teardown state: `teardown` sends `true`, `reset` sends `false`.
    teardown: watch::Sender<bool>,
}
 257
/// A held lock on the `ConnectionPool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is not `Send`, so this marker makes the whole guard `!Send`.
    // NOTE(review): presumably so the guard cannot be held across an `.await` in a
    // future that must be `Send` — confirm intent.
    _not_send: PhantomData<Rc<()>>,
}
 262
/// Serializable view of a server's peer and connection pool.
///
/// The struct holds the connection-pool guard, so the pool stays locked for as long as
/// the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through `serialize_deref`, i.e. the value the guard dereferences to is
    // serialized rather than the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 269
 270pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 271where
 272    S: Serializer,
 273    T: Deref<Target = U>,
 274    U: Serialize,
 275{
 276    Serialize::serialize(value.deref(), serializer)
 277}
 278
 279impl Server {
    /// Creates a server with the given id and registers one handler per RPC message type.
    ///
    /// Handlers registered with `add_request_handler` serve request messages; those
    /// registered with `add_message_handler` serve plain messages. Each message type may
    /// be registered only once.
    ///
    /// # Panics
    ///
    /// Panics if two handlers are registered for the same message type (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` silently wraps/truncates if the id is negative or
            // wider than 32 bits — confirm the id range makes this safe.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender half is kept; the initial teardown state is `false`.
            teardown: watch::channel(false).0,
        };

        server
            // Rooms and calls.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and project state updates.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests routed through the generic forwarders. Each message type
            // is registered with either the read-only or the mutating forwarder.
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_find_search_candidates_request)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LanguageServerIdForName>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer traffic and messages broadcast from the project host.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, channel chat and notifications.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Following, account information and tokens.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_request_handler(get_llm_api_token)
            .add_request_handler(accept_terms_of_service)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            // Contexts, git operations and breakpoints.
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            // NOTE(review): `GitReset` and `GitCheckoutFiles` sound like state-changing
            // operations, yet they are registered with the read-only forwarder. The
            // forwarders are defined outside this view — confirm this is intended.
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context);

        Arc::new(server)
    }
 458
    /// Schedules the cleanup of stale resources and returns immediately.
    ///
    /// A detached task is spawned on the app executor. After `CLEANUP_TIMEOUT` it:
    /// 1. deletes stale channel chat participants,
    /// 2. fetches the stale room and channel ids for this environment/server id and, for each,
    ///    clears stale collaborators/participants and notifies the affected connections,
    /// 3. deletes stale channel chat participants again, clears old worktree entries and
    ///    deletes stale servers.
    ///
    /// Every database or send failure inside the task goes through `trace_err` and the
    /// affected step is skipped; nothing is propagated to the caller.
    ///
    /// # Errors
    ///
    /// As written, this function always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        // Capture everything the detached task needs by value; it cannot borrow `self`.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: clear stale collaborators and send the refreshed
                    // collaborator list to every connection returned by the database.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Rooms: clear stale participants, then notify about canceled calls
                    // and refreshed contacts.
                    for room_id in room_ids {
                        // These keep their defaults when clearing the room fails, in which
                        // case the steps below have nothing to do for this room.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // The LiveKit room is deleted only once no participants remain.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the next `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Send each affected user's refreshed contact entry to the
                        // connections of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                // The pool is locked only after both database calls have
                                // completed; no `.await` happens while it is held.
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // NOTE(review): this is the same call that already ran right after the
                // timeout — confirm whether the second pass is intentional.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 629
 630    pub fn teardown(&self) {
 631        self.peer.teardown();
 632        self.connection_pool.lock().reset();
 633        let _ = self.teardown.send(true);
 634    }
 635
 636    #[cfg(test)]
 637    pub fn reset(&self, id: ServerId) {
 638        self.teardown();
 639        *self.id.lock() = id;
 640        self.peer.reset(id.0 as u32);
 641        let _ = self.teardown.send(false);
 642    }
 643
 644    #[cfg(test)]
 645    pub fn id(&self) -> ServerId {
 646        *self.id.lock()
 647    }
 648
 649    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 650    where
 651        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 652        Fut: 'static + Send + Future<Output = Result<()>>,
 653        M: EnvelopedMessage,
 654    {
 655        let prev_handler = self.handlers.insert(
 656            TypeId::of::<M>(),
 657            Box::new(move |envelope, session| {
 658                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 659                let received_at = envelope.received_at;
 660                tracing::info!("message received");
 661                let start_time = Instant::now();
 662                let future = (handler)(*envelope, session);
 663                async move {
 664                    let result = future.await;
 665                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 666                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 667                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 668                    let payload_type = M::NAME;
 669
 670                    match result {
 671                        Err(error) => {
 672                            tracing::error!(
 673                                ?error,
 674                                total_duration_ms,
 675                                processing_duration_ms,
 676                                queue_duration_ms,
 677                                payload_type,
 678                                "error handling message"
 679                            )
 680                        }
 681                        Ok(()) => tracing::info!(
 682                            total_duration_ms,
 683                            processing_duration_ms,
 684                            queue_duration_ms,
 685                            "finished handling message"
 686                        ),
 687                    }
 688                }
 689                .boxed()
 690            }),
 691        );
 692        if prev_handler.is_some() {
 693            panic!("registered a handler for the same message twice");
 694        }
 695        self
 696    }
 697
 698    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 699    where
 700        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 701        Fut: 'static + Send + Future<Output = Result<()>>,
 702        M: EnvelopedMessage,
 703    {
 704        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 705        self
 706    }
 707
 708    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 709    where
 710        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 711        Fut: Send + Future<Output = Result<()>>,
 712        M: RequestMessage,
 713    {
 714        let handler = Arc::new(handler);
 715        self.add_handler(move |envelope, session| {
 716            let receipt = envelope.receipt();
 717            let handler = handler.clone();
 718            async move {
 719                let peer = session.peer.clone();
 720                let responded = Arc::new(AtomicBool::default());
 721                let response = Response {
 722                    peer: peer.clone(),
 723                    responded: responded.clone(),
 724                    receipt,
 725                };
 726                match (handler)(envelope.payload, response, session).await {
 727                    Ok(()) => {
 728                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 729                            Ok(())
 730                        } else {
 731                            Err(anyhow!("handler did not send a response"))?
 732                        }
 733                    }
 734                    Err(error) => {
 735                        let proto_err = match &error {
 736                            Error::Internal(err) => err.to_proto(),
 737                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 738                        };
 739                        peer.respond_with_error(receipt, proto_err)?;
 740                        Err(error)
 741                    }
 742                }
 743            }
 744        })
 745    }
 746
    /// Returns a future that drives a single client connection from start to finish.
    ///
    /// The future:
    /// 1. bails out immediately if the teardown flag is already set;
    /// 2. registers `connection` with the peer and builds a `Session` for it;
    /// 3. sends the initial client update (forwarding `send_connection_id`),
    ///    then drops `connection_guard`;
    /// 4. loops, dispatching each incoming message to the handler registered for
    ///    its payload type, until the I/O future finishes, the incoming stream
    ///    ends, or the teardown channel changes;
    /// 5. calls `connection_lost` for the session (skipped when the loop exits
    ///    because of teardown).
    ///
    /// `address`, `principal` and `geoip_country_code` are recorded on the
    /// tracing span that instruments the future.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections once teardown has been signalled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }

            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            // The connection id is only known now, so fill in the span field declared empty above.
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only constructed when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The guard is held only until the initial update has been sent.
            // NOTE(review): presumably this frees a slot in the concurrent-connection
            // limit enforced by `ConnectionGuard::try_acquire` — confirm.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps the number of in-flight handlers for this connection: a permit is
            // acquired before the next message is read and released when its handler ends.
            let concurrent_handlers = Arc::new(Semaphore::new(512));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    // Teardown returns immediately, without running the sign-out below.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    // Drive already-started foreground handlers before accepting new messages.
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span before the future is instrumented with it below.
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor; foreground
                                // messages are polled by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Drop any foreground handlers that are still pending before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 897
 898    async fn send_initial_client_update(
 899        &self,
 900        connection_id: ConnectionId,
 901        zed_version: ZedVersion,
 902        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 903        session: &Session,
 904    ) -> Result<()> {
 905        self.peer.send(
 906            connection_id,
 907            proto::Hello {
 908                peer_id: Some(connection_id.into()),
 909            },
 910        )?;
 911        tracing::info!("sent hello message");
 912        if let Some(send_connection_id) = send_connection_id.take() {
 913            let _ = send_connection_id.send(connection_id);
 914        }
 915
 916        match &session.principal {
 917            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 918                if !user.connected_once {
 919                    self.peer.send(connection_id, proto::ShowContacts {})?;
 920                    self.app_state
 921                        .db
 922                        .set_user_connected_once(user.id, true)
 923                        .await?;
 924                }
 925
 926                update_user_plan(session).await?;
 927
 928                let contacts = self.app_state.db.get_contacts(user.id).await?;
 929
 930                {
 931                    let mut pool = self.connection_pool.lock();
 932                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 933                    self.peer.send(
 934                        connection_id,
 935                        build_initial_contacts_update(contacts, &pool),
 936                    )?;
 937                }
 938
 939                if should_auto_subscribe_to_channels(zed_version) {
 940                    subscribe_user_to_channels(user.id, session).await?;
 941                }
 942
 943                if let Some(incoming_call) =
 944                    self.app_state.db.incoming_call_for_user(user.id).await?
 945                {
 946                    self.peer.send(connection_id, incoming_call)?;
 947                }
 948
 949                update_user_contacts(user.id, session).await?;
 950            }
 951        }
 952
 953        Ok(())
 954    }
 955
 956    pub async fn invite_code_redeemed(
 957        self: &Arc<Self>,
 958        inviter_id: UserId,
 959        invitee_id: UserId,
 960    ) -> Result<()> {
 961        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 962            if let Some(code) = &user.invite_code {
 963                let pool = self.connection_pool.lock();
 964                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 965                for connection_id in pool.user_connection_ids(inviter_id) {
 966                    self.peer.send(
 967                        connection_id,
 968                        proto::UpdateContacts {
 969                            contacts: vec![invitee_contact.clone()],
 970                            ..Default::default()
 971                        },
 972                    )?;
 973                    self.peer.send(
 974                        connection_id,
 975                        proto::UpdateInviteInfo {
 976                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 977                            count: user.invite_count as u32,
 978                        },
 979                    )?;
 980                }
 981            }
 982        }
 983        Ok(())
 984    }
 985
 986    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 987        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 988            if let Some(invite_code) = &user.invite_code {
 989                let pool = self.connection_pool.lock();
 990                for connection_id in pool.user_connection_ids(user_id) {
 991                    self.peer.send(
 992                        connection_id,
 993                        proto::UpdateInviteInfo {
 994                            url: format!(
 995                                "{}{}",
 996                                self.app_state.config.invite_link_prefix, invite_code
 997                            ),
 998                            count: user.invite_count as u32,
 999                        },
1000                    )?;
1001                }
1002            }
1003        }
1004        Ok(())
1005    }
1006
1007    pub async fn update_plan_for_user(
1008        self: &Arc<Self>,
1009        user_id: UserId,
1010        update_user_plan: proto::UpdateUserPlan,
1011    ) -> Result<()> {
1012        let pool = self.connection_pool.lock();
1013        for connection_id in pool.user_connection_ids(user_id) {
1014            self.peer
1015                .send(connection_id, update_user_plan.clone())
1016                .trace_err();
1017        }
1018
1019        Ok(())
1020    }
1021
1022    /// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan`
1023    /// message on the Collab server.
1024    ///
1025    /// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint.
1026    pub async fn update_plan_for_user_legacy(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1027        let user = self
1028            .app_state
1029            .db
1030            .get_user_by_id(user_id)
1031            .await?
1032            .context("user not found")?;
1033
1034        let update_user_plan = make_update_user_plan_message(
1035            &user,
1036            user.admin,
1037            &self.app_state.db,
1038            self.app_state.llm_db.clone(),
1039        )
1040        .await?;
1041
1042        self.update_plan_for_user(user_id, update_user_plan).await
1043    }
1044
1045    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
1046        let pool = self.connection_pool.lock();
1047        for connection_id in pool.user_connection_ids(user_id) {
1048            self.peer
1049                .send(connection_id, proto::RefreshLlmToken {})
1050                .trace_err();
1051        }
1052    }
1053
1054    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1055        ServerSnapshot {
1056            connection_pool: ConnectionPoolGuard {
1057                guard: self.connection_pool.lock(),
1058                _not_send: PhantomData,
1059            },
1060            peer: &self.peer,
1061        }
1062    }
1063}
1064
1065impl Deref for ConnectionPoolGuard<'_> {
1066    type Target = ConnectionPool;
1067
1068    fn deref(&self) -> &Self::Target {
1069        &self.guard
1070    }
1071}
1072
1073impl DerefMut for ConnectionPoolGuard<'_> {
1074    fn deref_mut(&mut self) -> &mut Self::Target {
1075        &mut self.guard
1076    }
1077}
1078
impl Drop for ConnectionPoolGuard<'_> {
    /// In test builds, runs `check_invariants` when the guard is dropped.
    /// In non-test builds the call is compiled out and this does nothing.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1085
1086fn broadcast<F>(
1087    sender_id: Option<ConnectionId>,
1088    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1089    mut f: F,
1090) where
1091    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1092{
1093    for receiver_id in receiver_ids {
1094        if Some(receiver_id) != sender_id {
1095            if let Err(error) = f(receiver_id) {
1096                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1097            }
1098        }
1099    }
1100}
1101
1102pub struct ProtocolVersion(u32);
1103
1104impl Header for ProtocolVersion {
1105    fn name() -> &'static HeaderName {
1106        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1107        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1108    }
1109
1110    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1111    where
1112        Self: Sized,
1113        I: Iterator<Item = &'i axum::http::HeaderValue>,
1114    {
1115        let version = values
1116            .next()
1117            .ok_or_else(axum::headers::Error::invalid)?
1118            .to_str()
1119            .map_err(|_| axum::headers::Error::invalid())?
1120            .parse()
1121            .map_err(|_| axum::headers::Error::invalid())?;
1122        Ok(Self(version))
1123    }
1124
1125    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1126        values.extend([self.0.to_string().parse().unwrap()]);
1127    }
1128}
1129
1130pub struct AppVersionHeader(SemanticVersion);
1131impl Header for AppVersionHeader {
1132    fn name() -> &'static HeaderName {
1133        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1134        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1135    }
1136
1137    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1138    where
1139        Self: Sized,
1140        I: Iterator<Item = &'i axum::http::HeaderValue>,
1141    {
1142        let version = values
1143            .next()
1144            .ok_or_else(axum::headers::Error::invalid)?
1145            .to_str()
1146            .map_err(|_| axum::headers::Error::invalid())?
1147            .parse()
1148            .map_err(|_| axum::headers::Error::invalid())?;
1149        Ok(Self(version))
1150    }
1151
1152    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1153        values.extend([self.0.to_string().parse().unwrap()]);
1154    }
1155}
1156
1157pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1158    Router::new()
1159        .route("/rpc", get(handle_websocket_request))
1160        .layer(
1161            ServiceBuilder::new()
1162                .layer(Extension(server.app_state.clone()))
1163                .layer(middleware::from_fn(auth::validate_header)),
1164        )
1165        .route("/metrics", get(handle_metrics))
1166        .layer(Extension(server))
1167}
1168
1169pub async fn handle_websocket_request(
1170    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1171    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1172    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1173    Extension(server): Extension<Arc<Server>>,
1174    Extension(principal): Extension<Principal>,
1175    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1176    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1177    ws: WebSocketUpgrade,
1178) -> axum::response::Response {
1179    if protocol_version != rpc::PROTOCOL_VERSION {
1180        return (
1181            StatusCode::UPGRADE_REQUIRED,
1182            "client must be upgraded".to_string(),
1183        )
1184            .into_response();
1185    }
1186
1187    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1188        return (
1189            StatusCode::UPGRADE_REQUIRED,
1190            "no version header found".to_string(),
1191        )
1192            .into_response();
1193    };
1194
1195    if !version.can_collaborate() {
1196        return (
1197            StatusCode::UPGRADE_REQUIRED,
1198            "client must be upgraded".to_string(),
1199        )
1200            .into_response();
1201    }
1202
1203    let socket_address = socket_address.to_string();
1204
1205    // Acquire connection guard before WebSocket upgrade
1206    let connection_guard = match ConnectionGuard::try_acquire() {
1207        Ok(guard) => guard,
1208        Err(()) => {
1209            return (
1210                StatusCode::SERVICE_UNAVAILABLE,
1211                "Too many concurrent connections",
1212            )
1213                .into_response();
1214        }
1215    };
1216
1217    ws.on_upgrade(move |socket| {
1218        let socket = socket
1219            .map_ok(to_tungstenite_message)
1220            .err_into()
1221            .with(|message| async move { to_axum_message(message) });
1222        let connection = Connection::new(Box::pin(socket));
1223        async move {
1224            server
1225                .handle_connection(
1226                    connection,
1227                    socket_address,
1228                    principal,
1229                    version,
1230                    country_code_header.map(|header| header.to_string()),
1231                    system_id_header.map(|header| header.to_string()),
1232                    None,
1233                    Executor::Production,
1234                    Some(connection_guard),
1235                )
1236                .await;
1237        }
1238    })
1239}
1240
1241pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1242    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1243    let connections_metric = CONNECTIONS_METRIC
1244        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1245
1246    let connections = server
1247        .connection_pool
1248        .lock()
1249        .connections()
1250        .filter(|connection| !connection.admin)
1251        .count();
1252    connections_metric.set(connections as _);
1253
1254    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1255    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1256        register_int_gauge!(
1257            "shared_projects",
1258            "number of open projects with one or more guests"
1259        )
1260        .unwrap()
1261    });
1262
1263    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1264    shared_projects_metric.set(shared_projects as _);
1265
1266    let encoder = prometheus::TextEncoder::new();
1267    let metric_families = prometheus::gather();
1268    let encoded_metrics = encoder
1269        .encode_to_string(&metric_families)
1270        .map_err(|err| anyhow!("{err}"))?;
1271    Ok(encoded_metrics)
1272}
1273
1274#[instrument(err, skip(executor))]
1275async fn connection_lost(
1276    session: Session,
1277    mut teardown: watch::Receiver<bool>,
1278    executor: Executor,
1279) -> Result<()> {
1280    session.peer.disconnect(session.connection_id);
1281    session
1282        .connection_pool()
1283        .await
1284        .remove_connection(session.connection_id)?;
1285
1286    session
1287        .db()
1288        .await
1289        .connection_lost(session.connection_id)
1290        .await
1291        .trace_err();
1292
1293    futures::select_biased! {
1294        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1295
1296            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1297            leave_room_for_session(&session, session.connection_id).await.trace_err();
1298            leave_channel_buffers_for_session(&session)
1299                .await
1300                .trace_err();
1301
1302            if !session
1303                .connection_pool()
1304                .await
1305                .is_user_online(session.user_id())
1306            {
1307                let db = session.db().await;
1308                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1309                    room_updated(&room, &session.peer);
1310                }
1311            }
1312
1313            update_user_contacts(session.user_id(), &session).await?;
1314        },
1315        _ = teardown.changed().fuse() => {}
1316    }
1317
1318    Ok(())
1319}
1320
1321/// Acknowledges a ping from a client, used to keep the connection alive.
1322async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1323    response.send(proto::Ack {})?;
1324    Ok(())
1325}
1326
1327/// Creates a new room for calling (outside of channels)
1328async fn create_room(
1329    _request: proto::CreateRoom,
1330    response: Response<proto::CreateRoom>,
1331    session: Session,
1332) -> Result<()> {
1333    let livekit_room = nanoid::nanoid!(30);
1334
1335    let live_kit_connection_info = util::maybe!(async {
1336        let live_kit = session.app_state.livekit_client.as_ref();
1337        let live_kit = live_kit?;
1338        let user_id = session.user_id().to_string();
1339
1340        let token = live_kit
1341            .room_token(&livekit_room, &user_id.to_string())
1342            .trace_err()?;
1343
1344        Some(proto::LiveKitConnectionInfo {
1345            server_url: live_kit.url().into(),
1346            token,
1347            can_publish: true,
1348        })
1349    })
1350    .await;
1351
1352    let room = session
1353        .db()
1354        .await
1355        .create_room(session.user_id(), session.connection_id, &livekit_room)
1356        .await?;
1357
1358    response.send(proto::CreateRoomResponse {
1359        room: Some(room.clone()),
1360        live_kit_connection_info,
1361    })?;
1362
1363    update_user_contacts(session.user_id(), &session).await?;
1364    Ok(())
1365}
1366
1367/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1368async fn join_room(
1369    request: proto::JoinRoom,
1370    response: Response<proto::JoinRoom>,
1371    session: Session,
1372) -> Result<()> {
1373    let room_id = RoomId::from_proto(request.id);
1374
1375    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1376
1377    if let Some(channel_id) = channel_id {
1378        return join_channel_internal(channel_id, Box::new(response), session).await;
1379    }
1380
1381    let joined_room = {
1382        let room = session
1383            .db()
1384            .await
1385            .join_room(room_id, session.user_id(), session.connection_id)
1386            .await?;
1387        room_updated(&room.room, &session.peer);
1388        room.into_inner()
1389    };
1390
1391    for connection_id in session
1392        .connection_pool()
1393        .await
1394        .user_connection_ids(session.user_id())
1395    {
1396        session
1397            .peer
1398            .send(
1399                connection_id,
1400                proto::CallCanceled {
1401                    room_id: room_id.to_proto(),
1402                },
1403            )
1404            .trace_err();
1405    }
1406
1407    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1408    {
1409        live_kit
1410            .room_token(
1411                &joined_room.room.livekit_room,
1412                &session.user_id().to_string(),
1413            )
1414            .trace_err()
1415            .map(|token| proto::LiveKitConnectionInfo {
1416                server_url: live_kit.url().into(),
1417                token,
1418                can_publish: true,
1419            })
1420    } else {
1421        None
1422    };
1423
1424    response.send(proto::JoinRoomResponse {
1425        room: Some(joined_room.room),
1426        channel_id: None,
1427        live_kit_connection_info,
1428    })?;
1429
1430    update_user_contacts(session.user_id(), &session).await?;
1431    Ok(())
1432}
1433
1434/// Rejoin room is used to reconnect to a room after connection errors.
1435async fn rejoin_room(
1436    request: proto::RejoinRoom,
1437    response: Response<proto::RejoinRoom>,
1438    session: Session,
1439) -> Result<()> {
1440    let room;
1441    let channel;
1442    {
1443        let mut rejoined_room = session
1444            .db()
1445            .await
1446            .rejoin_room(request, session.user_id(), session.connection_id)
1447            .await?;
1448
1449        response.send(proto::RejoinRoomResponse {
1450            room: Some(rejoined_room.room.clone()),
1451            reshared_projects: rejoined_room
1452                .reshared_projects
1453                .iter()
1454                .map(|project| proto::ResharedProject {
1455                    id: project.id.to_proto(),
1456                    collaborators: project
1457                        .collaborators
1458                        .iter()
1459                        .map(|collaborator| collaborator.to_proto())
1460                        .collect(),
1461                })
1462                .collect(),
1463            rejoined_projects: rejoined_room
1464                .rejoined_projects
1465                .iter()
1466                .map(|rejoined_project| rejoined_project.to_proto())
1467                .collect(),
1468        })?;
1469        room_updated(&rejoined_room.room, &session.peer);
1470
1471        for project in &rejoined_room.reshared_projects {
1472            for collaborator in &project.collaborators {
1473                session
1474                    .peer
1475                    .send(
1476                        collaborator.connection_id,
1477                        proto::UpdateProjectCollaborator {
1478                            project_id: project.id.to_proto(),
1479                            old_peer_id: Some(project.old_connection_id.into()),
1480                            new_peer_id: Some(session.connection_id.into()),
1481                        },
1482                    )
1483                    .trace_err();
1484            }
1485
1486            broadcast(
1487                Some(session.connection_id),
1488                project
1489                    .collaborators
1490                    .iter()
1491                    .map(|collaborator| collaborator.connection_id),
1492                |connection_id| {
1493                    session.peer.forward_send(
1494                        session.connection_id,
1495                        connection_id,
1496                        proto::UpdateProject {
1497                            project_id: project.id.to_proto(),
1498                            worktrees: project.worktrees.clone(),
1499                        },
1500                    )
1501                },
1502            );
1503        }
1504
1505        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1506
1507        let rejoined_room = rejoined_room.into_inner();
1508
1509        room = rejoined_room.room;
1510        channel = rejoined_room.channel;
1511    }
1512
1513    if let Some(channel) = channel {
1514        channel_updated(
1515            &channel,
1516            &room,
1517            &session.peer,
1518            &*session.connection_pool().await,
1519        );
1520    }
1521
1522    update_user_contacts(session.user_id(), &session).await?;
1523    Ok(())
1524}
1525
1526fn notify_rejoined_projects(
1527    rejoined_projects: &mut Vec<RejoinedProject>,
1528    session: &Session,
1529) -> Result<()> {
1530    for project in rejoined_projects.iter() {
1531        for collaborator in &project.collaborators {
1532            session
1533                .peer
1534                .send(
1535                    collaborator.connection_id,
1536                    proto::UpdateProjectCollaborator {
1537                        project_id: project.id.to_proto(),
1538                        old_peer_id: Some(project.old_connection_id.into()),
1539                        new_peer_id: Some(session.connection_id.into()),
1540                    },
1541                )
1542                .trace_err();
1543        }
1544    }
1545
1546    for project in rejoined_projects {
1547        for worktree in mem::take(&mut project.worktrees) {
1548            // Stream this worktree's entries.
1549            let message = proto::UpdateWorktree {
1550                project_id: project.id.to_proto(),
1551                worktree_id: worktree.id,
1552                abs_path: worktree.abs_path.clone(),
1553                root_name: worktree.root_name,
1554                updated_entries: worktree.updated_entries,
1555                removed_entries: worktree.removed_entries,
1556                scan_id: worktree.scan_id,
1557                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1558                updated_repositories: worktree.updated_repositories,
1559                removed_repositories: worktree.removed_repositories,
1560            };
1561            for update in proto::split_worktree_update(message) {
1562                session.peer.send(session.connection_id, update)?;
1563            }
1564
1565            // Stream this worktree's diagnostics.
1566            for summary in worktree.diagnostic_summaries {
1567                session.peer.send(
1568                    session.connection_id,
1569                    proto::UpdateDiagnosticSummary {
1570                        project_id: project.id.to_proto(),
1571                        worktree_id: worktree.id,
1572                        summary: Some(summary),
1573                    },
1574                )?;
1575            }
1576
1577            for settings_file in worktree.settings_files {
1578                session.peer.send(
1579                    session.connection_id,
1580                    proto::UpdateWorktreeSettings {
1581                        project_id: project.id.to_proto(),
1582                        worktree_id: worktree.id,
1583                        path: settings_file.path,
1584                        content: Some(settings_file.content),
1585                        kind: Some(settings_file.kind.to_proto().into()),
1586                    },
1587                )?;
1588            }
1589        }
1590
1591        for repository in mem::take(&mut project.updated_repositories) {
1592            for update in split_repository_update(repository) {
1593                session.peer.send(session.connection_id, update)?;
1594            }
1595        }
1596
1597        for id in mem::take(&mut project.removed_repositories) {
1598            session.peer.send(
1599                session.connection_id,
1600                proto::RemoveRepository {
1601                    project_id: project.id.to_proto(),
1602                    id,
1603                },
1604            )?;
1605        }
1606    }
1607
1608    Ok(())
1609}
1610
1611/// leave room disconnects from the room.
1612async fn leave_room(
1613    _: proto::LeaveRoom,
1614    response: Response<proto::LeaveRoom>,
1615    session: Session,
1616) -> Result<()> {
1617    leave_room_for_session(&session, session.connection_id).await?;
1618    response.send(proto::Ack {})?;
1619    Ok(())
1620}
1621
1622/// Updates the permissions of someone else in the room.
1623async fn set_room_participant_role(
1624    request: proto::SetRoomParticipantRole,
1625    response: Response<proto::SetRoomParticipantRole>,
1626    session: Session,
1627) -> Result<()> {
1628    let user_id = UserId::from_proto(request.user_id);
1629    let role = ChannelRole::from(request.role());
1630
1631    let (livekit_room, can_publish) = {
1632        let room = session
1633            .db()
1634            .await
1635            .set_room_participant_role(
1636                session.user_id(),
1637                RoomId::from_proto(request.room_id),
1638                user_id,
1639                role,
1640            )
1641            .await?;
1642
1643        let livekit_room = room.livekit_room.clone();
1644        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1645        room_updated(&room, &session.peer);
1646        (livekit_room, can_publish)
1647    };
1648
1649    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1650        live_kit
1651            .update_participant(
1652                livekit_room.clone(),
1653                request.user_id.to_string(),
1654                livekit_api::proto::ParticipantPermission {
1655                    can_subscribe: true,
1656                    can_publish,
1657                    can_publish_data: can_publish,
1658                    hidden: false,
1659                    recorder: false,
1660                },
1661            )
1662            .await
1663            .trace_err();
1664    }
1665
1666    response.send(proto::Ack {})?;
1667    Ok(())
1668}
1669
1670/// Call someone else into the current room
1671async fn call(
1672    request: proto::Call,
1673    response: Response<proto::Call>,
1674    session: Session,
1675) -> Result<()> {
1676    let room_id = RoomId::from_proto(request.room_id);
1677    let calling_user_id = session.user_id();
1678    let calling_connection_id = session.connection_id;
1679    let called_user_id = UserId::from_proto(request.called_user_id);
1680    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1681    if !session
1682        .db()
1683        .await
1684        .has_contact(calling_user_id, called_user_id)
1685        .await?
1686    {
1687        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1688    }
1689
1690    let incoming_call = {
1691        let (room, incoming_call) = &mut *session
1692            .db()
1693            .await
1694            .call(
1695                room_id,
1696                calling_user_id,
1697                calling_connection_id,
1698                called_user_id,
1699                initial_project_id,
1700            )
1701            .await?;
1702        room_updated(room, &session.peer);
1703        mem::take(incoming_call)
1704    };
1705    update_user_contacts(called_user_id, &session).await?;
1706
1707    let mut calls = session
1708        .connection_pool()
1709        .await
1710        .user_connection_ids(called_user_id)
1711        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1712        .collect::<FuturesUnordered<_>>();
1713
1714    while let Some(call_response) = calls.next().await {
1715        match call_response.as_ref() {
1716            Ok(_) => {
1717                response.send(proto::Ack {})?;
1718                return Ok(());
1719            }
1720            Err(_) => {
1721                call_response.trace_err();
1722            }
1723        }
1724    }
1725
1726    {
1727        let room = session
1728            .db()
1729            .await
1730            .call_failed(room_id, called_user_id)
1731            .await?;
1732        room_updated(&room, &session.peer);
1733    }
1734    update_user_contacts(called_user_id, &session).await?;
1735
1736    Err(anyhow!("failed to ring user"))?
1737}
1738
1739/// Cancel an outgoing call.
1740async fn cancel_call(
1741    request: proto::CancelCall,
1742    response: Response<proto::CancelCall>,
1743    session: Session,
1744) -> Result<()> {
1745    let called_user_id = UserId::from_proto(request.called_user_id);
1746    let room_id = RoomId::from_proto(request.room_id);
1747    {
1748        let room = session
1749            .db()
1750            .await
1751            .cancel_call(room_id, session.connection_id, called_user_id)
1752            .await?;
1753        room_updated(&room, &session.peer);
1754    }
1755
1756    for connection_id in session
1757        .connection_pool()
1758        .await
1759        .user_connection_ids(called_user_id)
1760    {
1761        session
1762            .peer
1763            .send(
1764                connection_id,
1765                proto::CallCanceled {
1766                    room_id: room_id.to_proto(),
1767                },
1768            )
1769            .trace_err();
1770    }
1771    response.send(proto::Ack {})?;
1772
1773    update_user_contacts(called_user_id, &session).await?;
1774    Ok(())
1775}
1776
1777/// Decline an incoming call.
1778async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1779    let room_id = RoomId::from_proto(message.room_id);
1780    {
1781        let room = session
1782            .db()
1783            .await
1784            .decline_call(Some(room_id), session.user_id())
1785            .await?
1786            .context("declining call")?;
1787        room_updated(&room, &session.peer);
1788    }
1789
1790    for connection_id in session
1791        .connection_pool()
1792        .await
1793        .user_connection_ids(session.user_id())
1794    {
1795        session
1796            .peer
1797            .send(
1798                connection_id,
1799                proto::CallCanceled {
1800                    room_id: room_id.to_proto(),
1801                },
1802            )
1803            .trace_err();
1804    }
1805    update_user_contacts(session.user_id(), &session).await?;
1806    Ok(())
1807}
1808
1809/// Updates other participants in the room with your current location.
1810async fn update_participant_location(
1811    request: proto::UpdateParticipantLocation,
1812    response: Response<proto::UpdateParticipantLocation>,
1813    session: Session,
1814) -> Result<()> {
1815    let room_id = RoomId::from_proto(request.room_id);
1816    let location = request.location.context("invalid location")?;
1817
1818    let db = session.db().await;
1819    let room = db
1820        .update_room_participant_location(room_id, session.connection_id, location)
1821        .await?;
1822
1823    room_updated(&room, &session.peer);
1824    response.send(proto::Ack {})?;
1825    Ok(())
1826}
1827
1828/// Share a project into the room.
1829async fn share_project(
1830    request: proto::ShareProject,
1831    response: Response<proto::ShareProject>,
1832    session: Session,
1833) -> Result<()> {
1834    let (project_id, room) = &*session
1835        .db()
1836        .await
1837        .share_project(
1838            RoomId::from_proto(request.room_id),
1839            session.connection_id,
1840            &request.worktrees,
1841            request.is_ssh_project,
1842        )
1843        .await?;
1844    response.send(proto::ShareProjectResponse {
1845        project_id: project_id.to_proto(),
1846    })?;
1847    room_updated(room, &session.peer);
1848
1849    Ok(())
1850}
1851
1852/// Unshare a project from the room.
1853async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1854    let project_id = ProjectId::from_proto(message.project_id);
1855    unshare_project_internal(project_id, session.connection_id, &session).await
1856}
1857
1858async fn unshare_project_internal(
1859    project_id: ProjectId,
1860    connection_id: ConnectionId,
1861    session: &Session,
1862) -> Result<()> {
1863    let delete = {
1864        let room_guard = session
1865            .db()
1866            .await
1867            .unshare_project(project_id, connection_id)
1868            .await?;
1869
1870        let (delete, room, guest_connection_ids) = &*room_guard;
1871
1872        let message = proto::UnshareProject {
1873            project_id: project_id.to_proto(),
1874        };
1875
1876        broadcast(
1877            Some(connection_id),
1878            guest_connection_ids.iter().copied(),
1879            |conn_id| session.peer.send(conn_id, message.clone()),
1880        );
1881        if let Some(room) = room {
1882            room_updated(room, &session.peer);
1883        }
1884
1885        *delete
1886    };
1887
1888    if delete {
1889        let db = session.db().await;
1890        db.delete_project(project_id).await?;
1891    }
1892
1893    Ok(())
1894}
1895
1896/// Join someone elses shared project.
1897async fn join_project(
1898    request: proto::JoinProject,
1899    response: Response<proto::JoinProject>,
1900    session: Session,
1901) -> Result<()> {
1902    let project_id = ProjectId::from_proto(request.project_id);
1903
1904    tracing::info!(%project_id, "join project");
1905
1906    let db = session.db().await;
1907    let (project, replica_id) = &mut *db
1908        .join_project(
1909            project_id,
1910            session.connection_id,
1911            session.user_id(),
1912            request.committer_name.clone(),
1913            request.committer_email.clone(),
1914        )
1915        .await?;
1916    drop(db);
1917    tracing::info!(%project_id, "join remote project");
1918    let collaborators = project
1919        .collaborators
1920        .iter()
1921        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1922        .map(|collaborator| collaborator.to_proto())
1923        .collect::<Vec<_>>();
1924    let project_id = project.id;
1925    let guest_user_id = session.user_id();
1926
1927    let worktrees = project
1928        .worktrees
1929        .iter()
1930        .map(|(id, worktree)| proto::WorktreeMetadata {
1931            id: *id,
1932            root_name: worktree.root_name.clone(),
1933            visible: worktree.visible,
1934            abs_path: worktree.abs_path.clone(),
1935        })
1936        .collect::<Vec<_>>();
1937
1938    let add_project_collaborator = proto::AddProjectCollaborator {
1939        project_id: project_id.to_proto(),
1940        collaborator: Some(proto::Collaborator {
1941            peer_id: Some(session.connection_id.into()),
1942            replica_id: replica_id.0 as u32,
1943            user_id: guest_user_id.to_proto(),
1944            is_host: false,
1945            committer_name: request.committer_name.clone(),
1946            committer_email: request.committer_email.clone(),
1947        }),
1948    };
1949
1950    for collaborator in &collaborators {
1951        session
1952            .peer
1953            .send(
1954                collaborator.peer_id.unwrap().into(),
1955                add_project_collaborator.clone(),
1956            )
1957            .trace_err();
1958    }
1959
1960    // First, we send the metadata associated with each worktree.
1961    response.send(proto::JoinProjectResponse {
1962        project_id: project.id.0 as u64,
1963        worktrees: worktrees.clone(),
1964        replica_id: replica_id.0 as u32,
1965        collaborators: collaborators.clone(),
1966        language_servers: project.language_servers.clone(),
1967        role: project.role.into(),
1968    })?;
1969
1970    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1971        // Stream this worktree's entries.
1972        let message = proto::UpdateWorktree {
1973            project_id: project_id.to_proto(),
1974            worktree_id,
1975            abs_path: worktree.abs_path.clone(),
1976            root_name: worktree.root_name,
1977            updated_entries: worktree.entries,
1978            removed_entries: Default::default(),
1979            scan_id: worktree.scan_id,
1980            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1981            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1982            removed_repositories: Default::default(),
1983        };
1984        for update in proto::split_worktree_update(message) {
1985            session.peer.send(session.connection_id, update.clone())?;
1986        }
1987
1988        // Stream this worktree's diagnostics.
1989        for summary in worktree.diagnostic_summaries {
1990            session.peer.send(
1991                session.connection_id,
1992                proto::UpdateDiagnosticSummary {
1993                    project_id: project_id.to_proto(),
1994                    worktree_id: worktree.id,
1995                    summary: Some(summary),
1996                },
1997            )?;
1998        }
1999
2000        for settings_file in worktree.settings_files {
2001            session.peer.send(
2002                session.connection_id,
2003                proto::UpdateWorktreeSettings {
2004                    project_id: project_id.to_proto(),
2005                    worktree_id: worktree.id,
2006                    path: settings_file.path,
2007                    content: Some(settings_file.content),
2008                    kind: Some(settings_file.kind.to_proto() as i32),
2009                },
2010            )?;
2011        }
2012    }
2013
2014    for repository in mem::take(&mut project.repositories) {
2015        for update in split_repository_update(repository) {
2016            session.peer.send(session.connection_id, update)?;
2017        }
2018    }
2019
2020    for language_server in &project.language_servers {
2021        session.peer.send(
2022            session.connection_id,
2023            proto::UpdateLanguageServer {
2024                project_id: project_id.to_proto(),
2025                server_name: Some(language_server.name.clone()),
2026                language_server_id: language_server.id,
2027                variant: Some(
2028                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2029                        proto::LspDiskBasedDiagnosticsUpdated {},
2030                    ),
2031                ),
2032            },
2033        )?;
2034    }
2035
2036    Ok(())
2037}
2038
2039/// Leave someone elses shared project.
2040async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
2041    let sender_id = session.connection_id;
2042    let project_id = ProjectId::from_proto(request.project_id);
2043    let db = session.db().await;
2044
2045    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2046    tracing::info!(
2047        %project_id,
2048        "leave project"
2049    );
2050
2051    project_left(project, &session);
2052    if let Some(room) = room {
2053        room_updated(room, &session.peer);
2054    }
2055
2056    Ok(())
2057}
2058
2059/// Updates other participants with changes to the project
2060async fn update_project(
2061    request: proto::UpdateProject,
2062    response: Response<proto::UpdateProject>,
2063    session: Session,
2064) -> Result<()> {
2065    let project_id = ProjectId::from_proto(request.project_id);
2066    let (room, guest_connection_ids) = &*session
2067        .db()
2068        .await
2069        .update_project(project_id, session.connection_id, &request.worktrees)
2070        .await?;
2071    broadcast(
2072        Some(session.connection_id),
2073        guest_connection_ids.iter().copied(),
2074        |connection_id| {
2075            session
2076                .peer
2077                .forward_send(session.connection_id, connection_id, request.clone())
2078        },
2079    );
2080    if let Some(room) = room {
2081        room_updated(room, &session.peer);
2082    }
2083    response.send(proto::Ack {})?;
2084
2085    Ok(())
2086}
2087
2088/// Updates other participants with changes to the worktree
2089async fn update_worktree(
2090    request: proto::UpdateWorktree,
2091    response: Response<proto::UpdateWorktree>,
2092    session: Session,
2093) -> Result<()> {
2094    let guest_connection_ids = session
2095        .db()
2096        .await
2097        .update_worktree(&request, session.connection_id)
2098        .await?;
2099
2100    broadcast(
2101        Some(session.connection_id),
2102        guest_connection_ids.iter().copied(),
2103        |connection_id| {
2104            session
2105                .peer
2106                .forward_send(session.connection_id, connection_id, request.clone())
2107        },
2108    );
2109    response.send(proto::Ack {})?;
2110    Ok(())
2111}
2112
2113async fn update_repository(
2114    request: proto::UpdateRepository,
2115    response: Response<proto::UpdateRepository>,
2116    session: Session,
2117) -> Result<()> {
2118    let guest_connection_ids = session
2119        .db()
2120        .await
2121        .update_repository(&request, session.connection_id)
2122        .await?;
2123
2124    broadcast(
2125        Some(session.connection_id),
2126        guest_connection_ids.iter().copied(),
2127        |connection_id| {
2128            session
2129                .peer
2130                .forward_send(session.connection_id, connection_id, request.clone())
2131        },
2132    );
2133    response.send(proto::Ack {})?;
2134    Ok(())
2135}
2136
2137async fn remove_repository(
2138    request: proto::RemoveRepository,
2139    response: Response<proto::RemoveRepository>,
2140    session: Session,
2141) -> Result<()> {
2142    let guest_connection_ids = session
2143        .db()
2144        .await
2145        .remove_repository(&request, session.connection_id)
2146        .await?;
2147
2148    broadcast(
2149        Some(session.connection_id),
2150        guest_connection_ids.iter().copied(),
2151        |connection_id| {
2152            session
2153                .peer
2154                .forward_send(session.connection_id, connection_id, request.clone())
2155        },
2156    );
2157    response.send(proto::Ack {})?;
2158    Ok(())
2159}
2160
2161/// Updates other participants with changes to the diagnostics
2162async fn update_diagnostic_summary(
2163    message: proto::UpdateDiagnosticSummary,
2164    session: Session,
2165) -> Result<()> {
2166    let guest_connection_ids = session
2167        .db()
2168        .await
2169        .update_diagnostic_summary(&message, session.connection_id)
2170        .await?;
2171
2172    broadcast(
2173        Some(session.connection_id),
2174        guest_connection_ids.iter().copied(),
2175        |connection_id| {
2176            session
2177                .peer
2178                .forward_send(session.connection_id, connection_id, message.clone())
2179        },
2180    );
2181
2182    Ok(())
2183}
2184
2185/// Updates other participants with changes to the worktree settings
2186async fn update_worktree_settings(
2187    message: proto::UpdateWorktreeSettings,
2188    session: Session,
2189) -> Result<()> {
2190    let guest_connection_ids = session
2191        .db()
2192        .await
2193        .update_worktree_settings(&message, session.connection_id)
2194        .await?;
2195
2196    broadcast(
2197        Some(session.connection_id),
2198        guest_connection_ids.iter().copied(),
2199        |connection_id| {
2200            session
2201                .peer
2202                .forward_send(session.connection_id, connection_id, message.clone())
2203        },
2204    );
2205
2206    Ok(())
2207}
2208
2209/// Notify other participants that a language server has started.
2210async fn start_language_server(
2211    request: proto::StartLanguageServer,
2212    session: Session,
2213) -> Result<()> {
2214    let guest_connection_ids = session
2215        .db()
2216        .await
2217        .start_language_server(&request, session.connection_id)
2218        .await?;
2219
2220    broadcast(
2221        Some(session.connection_id),
2222        guest_connection_ids.iter().copied(),
2223        |connection_id| {
2224            session
2225                .peer
2226                .forward_send(session.connection_id, connection_id, request.clone())
2227        },
2228    );
2229    Ok(())
2230}
2231
2232/// Notify other participants that a language server has changed.
2233async fn update_language_server(
2234    request: proto::UpdateLanguageServer,
2235    session: Session,
2236) -> Result<()> {
2237    let project_id = ProjectId::from_proto(request.project_id);
2238    let project_connection_ids = session
2239        .db()
2240        .await
2241        .project_connection_ids(project_id, session.connection_id, true)
2242        .await?;
2243    broadcast(
2244        Some(session.connection_id),
2245        project_connection_ids.iter().copied(),
2246        |connection_id| {
2247            session
2248                .peer
2249                .forward_send(session.connection_id, connection_id, request.clone())
2250        },
2251    );
2252    Ok(())
2253}
2254
2255/// forward a project request to the host. These requests should be read only
2256/// as guests are allowed to send them.
2257async fn forward_read_only_project_request<T>(
2258    request: T,
2259    response: Response<T>,
2260    session: Session,
2261) -> Result<()>
2262where
2263    T: EntityMessage + RequestMessage,
2264{
2265    let project_id = ProjectId::from_proto(request.remote_entity_id());
2266    let host_connection_id = session
2267        .db()
2268        .await
2269        .host_for_read_only_project_request(project_id, session.connection_id)
2270        .await?;
2271    let payload = session
2272        .peer
2273        .forward_request(session.connection_id, host_connection_id, request)
2274        .await?;
2275    response.send(payload)?;
2276    Ok(())
2277}
2278
2279async fn forward_find_search_candidates_request(
2280    request: proto::FindSearchCandidates,
2281    response: Response<proto::FindSearchCandidates>,
2282    session: Session,
2283) -> Result<()> {
2284    let project_id = ProjectId::from_proto(request.remote_entity_id());
2285    let host_connection_id = session
2286        .db()
2287        .await
2288        .host_for_read_only_project_request(project_id, session.connection_id)
2289        .await?;
2290    let payload = session
2291        .peer
2292        .forward_request(session.connection_id, host_connection_id, request)
2293        .await?;
2294    response.send(payload)?;
2295    Ok(())
2296}
2297
2298/// forward a project request to the host. These requests are disallowed
2299/// for guests.
2300async fn forward_mutating_project_request<T>(
2301    request: T,
2302    response: Response<T>,
2303    session: Session,
2304) -> Result<()>
2305where
2306    T: EntityMessage + RequestMessage,
2307{
2308    let project_id = ProjectId::from_proto(request.remote_entity_id());
2309
2310    let host_connection_id = session
2311        .db()
2312        .await
2313        .host_for_mutating_project_request(project_id, session.connection_id)
2314        .await?;
2315    let payload = session
2316        .peer
2317        .forward_request(session.connection_id, host_connection_id, request)
2318        .await?;
2319    response.send(payload)?;
2320    Ok(())
2321}
2322
2323/// Notify other participants that a new buffer has been created
2324async fn create_buffer_for_peer(
2325    request: proto::CreateBufferForPeer,
2326    session: Session,
2327) -> Result<()> {
2328    session
2329        .db()
2330        .await
2331        .check_user_is_project_host(
2332            ProjectId::from_proto(request.project_id),
2333            session.connection_id,
2334        )
2335        .await?;
2336    let peer_id = request.peer_id.context("invalid peer id")?;
2337    session
2338        .peer
2339        .forward_send(session.connection_id, peer_id.into(), request)?;
2340    Ok(())
2341}
2342
2343/// Notify other participants that a buffer has been updated. This is
2344/// allowed for guests as long as the update is limited to selections.
2345async fn update_buffer(
2346    request: proto::UpdateBuffer,
2347    response: Response<proto::UpdateBuffer>,
2348    session: Session,
2349) -> Result<()> {
2350    let project_id = ProjectId::from_proto(request.project_id);
2351    let mut capability = Capability::ReadOnly;
2352
2353    for op in request.operations.iter() {
2354        match op.variant {
2355            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2356            Some(_) => capability = Capability::ReadWrite,
2357        }
2358    }
2359
2360    let host = {
2361        let guard = session
2362            .db()
2363            .await
2364            .connections_for_buffer_update(project_id, session.connection_id, capability)
2365            .await?;
2366
2367        let (host, guests) = &*guard;
2368
2369        broadcast(
2370            Some(session.connection_id),
2371            guests.clone(),
2372            |connection_id| {
2373                session
2374                    .peer
2375                    .forward_send(session.connection_id, connection_id, request.clone())
2376            },
2377        );
2378
2379        *host
2380    };
2381
2382    if host != session.connection_id {
2383        session
2384            .peer
2385            .forward_request(session.connection_id, host, request.clone())
2386            .await?;
2387    }
2388
2389    response.send(proto::Ack {})?;
2390    Ok(())
2391}
2392
2393async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2394    let project_id = ProjectId::from_proto(message.project_id);
2395
2396    let operation = message.operation.as_ref().context("invalid operation")?;
2397    let capability = match operation.variant.as_ref() {
2398        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2399            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2400                match buffer_op.variant {
2401                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2402                        Capability::ReadOnly
2403                    }
2404                    _ => Capability::ReadWrite,
2405                }
2406            } else {
2407                Capability::ReadWrite
2408            }
2409        }
2410        Some(_) => Capability::ReadWrite,
2411        None => Capability::ReadOnly,
2412    };
2413
2414    let guard = session
2415        .db()
2416        .await
2417        .connections_for_buffer_update(project_id, session.connection_id, capability)
2418        .await?;
2419
2420    let (host, guests) = &*guard;
2421
2422    broadcast(
2423        Some(session.connection_id),
2424        guests.iter().chain([host]).copied(),
2425        |connection_id| {
2426            session
2427                .peer
2428                .forward_send(session.connection_id, connection_id, message.clone())
2429        },
2430    );
2431
2432    Ok(())
2433}
2434
2435/// Notify other participants that a project has been updated.
2436async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2437    request: T,
2438    session: Session,
2439) -> Result<()> {
2440    let project_id = ProjectId::from_proto(request.remote_entity_id());
2441    let project_connection_ids = session
2442        .db()
2443        .await
2444        .project_connection_ids(project_id, session.connection_id, false)
2445        .await?;
2446
2447    broadcast(
2448        Some(session.connection_id),
2449        project_connection_ids.iter().copied(),
2450        |connection_id| {
2451            session
2452                .peer
2453                .forward_send(session.connection_id, connection_id, request.clone())
2454        },
2455    );
2456    Ok(())
2457}
2458
2459/// Start following another user in a call.
2460async fn follow(
2461    request: proto::Follow,
2462    response: Response<proto::Follow>,
2463    session: Session,
2464) -> Result<()> {
2465    let room_id = RoomId::from_proto(request.room_id);
2466    let project_id = request.project_id.map(ProjectId::from_proto);
2467    let leader_id = request.leader_id.context("invalid leader id")?.into();
2468    let follower_id = session.connection_id;
2469
2470    session
2471        .db()
2472        .await
2473        .check_room_participants(room_id, leader_id, session.connection_id)
2474        .await?;
2475
2476    let response_payload = session
2477        .peer
2478        .forward_request(session.connection_id, leader_id, request)
2479        .await?;
2480    response.send(response_payload)?;
2481
2482    if let Some(project_id) = project_id {
2483        let room = session
2484            .db()
2485            .await
2486            .follow(room_id, project_id, leader_id, follower_id)
2487            .await?;
2488        room_updated(&room, &session.peer);
2489    }
2490
2491    Ok(())
2492}
2493
2494/// Stop following another user in a call.
2495async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2496    let room_id = RoomId::from_proto(request.room_id);
2497    let project_id = request.project_id.map(ProjectId::from_proto);
2498    let leader_id = request.leader_id.context("invalid leader id")?.into();
2499    let follower_id = session.connection_id;
2500
2501    session
2502        .db()
2503        .await
2504        .check_room_participants(room_id, leader_id, session.connection_id)
2505        .await?;
2506
2507    session
2508        .peer
2509        .forward_send(session.connection_id, leader_id, request)?;
2510
2511    if let Some(project_id) = project_id {
2512        let room = session
2513            .db()
2514            .await
2515            .unfollow(room_id, project_id, leader_id, follower_id)
2516            .await?;
2517        room_updated(&room, &session.peer);
2518    }
2519
2520    Ok(())
2521}
2522
2523/// Notify everyone following you of your current location.
2524async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2525    let room_id = RoomId::from_proto(request.room_id);
2526    let database = session.db.lock().await;
2527
2528    let connection_ids = if let Some(project_id) = request.project_id {
2529        let project_id = ProjectId::from_proto(project_id);
2530        database
2531            .project_connection_ids(project_id, session.connection_id, true)
2532            .await?
2533    } else {
2534        database
2535            .room_connection_ids(room_id, session.connection_id)
2536            .await?
2537    };
2538
2539    // For now, don't send view update messages back to that view's current leader.
2540    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2541        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2542        _ => None,
2543    });
2544
2545    for connection_id in connection_ids.iter().cloned() {
2546        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2547            session
2548                .peer
2549                .forward_send(session.connection_id, connection_id, request.clone())?;
2550        }
2551    }
2552    Ok(())
2553}
2554
2555/// Get public data about users.
2556async fn get_users(
2557    request: proto::GetUsers,
2558    response: Response<proto::GetUsers>,
2559    session: Session,
2560) -> Result<()> {
2561    let user_ids = request
2562        .user_ids
2563        .into_iter()
2564        .map(UserId::from_proto)
2565        .collect();
2566    let users = session
2567        .db()
2568        .await
2569        .get_users_by_ids(user_ids)
2570        .await?
2571        .into_iter()
2572        .map(|user| proto::User {
2573            id: user.id.to_proto(),
2574            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2575            github_login: user.github_login,
2576            name: user.name,
2577        })
2578        .collect();
2579    response.send(proto::UsersResponse { users })?;
2580    Ok(())
2581}
2582
2583/// Search for users (to invite) buy Github login
2584async fn fuzzy_search_users(
2585    request: proto::FuzzySearchUsers,
2586    response: Response<proto::FuzzySearchUsers>,
2587    session: Session,
2588) -> Result<()> {
2589    let query = request.query;
2590    let users = match query.len() {
2591        0 => vec![],
2592        1 | 2 => session
2593            .db()
2594            .await
2595            .get_user_by_github_login(&query)
2596            .await?
2597            .into_iter()
2598            .collect(),
2599        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2600    };
2601    let users = users
2602        .into_iter()
2603        .filter(|user| user.id != session.user_id())
2604        .map(|user| proto::User {
2605            id: user.id.to_proto(),
2606            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2607            github_login: user.github_login,
2608            name: user.name,
2609        })
2610        .collect();
2611    response.send(proto::UsersResponse { users })?;
2612    Ok(())
2613}
2614
2615/// Send a contact request to another user.
2616async fn request_contact(
2617    request: proto::RequestContact,
2618    response: Response<proto::RequestContact>,
2619    session: Session,
2620) -> Result<()> {
2621    let requester_id = session.user_id();
2622    let responder_id = UserId::from_proto(request.responder_id);
2623    if requester_id == responder_id {
2624        return Err(anyhow!("cannot add yourself as a contact"))?;
2625    }
2626
2627    let notifications = session
2628        .db()
2629        .await
2630        .send_contact_request(requester_id, responder_id)
2631        .await?;
2632
2633    // Update outgoing contact requests of requester
2634    let mut update = proto::UpdateContacts::default();
2635    update.outgoing_requests.push(responder_id.to_proto());
2636    for connection_id in session
2637        .connection_pool()
2638        .await
2639        .user_connection_ids(requester_id)
2640    {
2641        session.peer.send(connection_id, update.clone())?;
2642    }
2643
2644    // Update incoming contact requests of responder
2645    let mut update = proto::UpdateContacts::default();
2646    update
2647        .incoming_requests
2648        .push(proto::IncomingContactRequest {
2649            requester_id: requester_id.to_proto(),
2650        });
2651    let connection_pool = session.connection_pool().await;
2652    for connection_id in connection_pool.user_connection_ids(responder_id) {
2653        session.peer.send(connection_id, update.clone())?;
2654    }
2655
2656    send_notifications(&connection_pool, &session.peer, notifications);
2657
2658    response.send(proto::Ack {})?;
2659    Ok(())
2660}
2661
2662/// Accept or decline a contact request
2663async fn respond_to_contact_request(
2664    request: proto::RespondToContactRequest,
2665    response: Response<proto::RespondToContactRequest>,
2666    session: Session,
2667) -> Result<()> {
2668    let responder_id = session.user_id();
2669    let requester_id = UserId::from_proto(request.requester_id);
2670    let db = session.db().await;
2671    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2672        db.dismiss_contact_notification(responder_id, requester_id)
2673            .await?;
2674    } else {
2675        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2676
2677        let notifications = db
2678            .respond_to_contact_request(responder_id, requester_id, accept)
2679            .await?;
2680        let requester_busy = db.is_user_busy(requester_id).await?;
2681        let responder_busy = db.is_user_busy(responder_id).await?;
2682
2683        let pool = session.connection_pool().await;
2684        // Update responder with new contact
2685        let mut update = proto::UpdateContacts::default();
2686        if accept {
2687            update
2688                .contacts
2689                .push(contact_for_user(requester_id, requester_busy, &pool));
2690        }
2691        update
2692            .remove_incoming_requests
2693            .push(requester_id.to_proto());
2694        for connection_id in pool.user_connection_ids(responder_id) {
2695            session.peer.send(connection_id, update.clone())?;
2696        }
2697
2698        // Update requester with new contact
2699        let mut update = proto::UpdateContacts::default();
2700        if accept {
2701            update
2702                .contacts
2703                .push(contact_for_user(responder_id, responder_busy, &pool));
2704        }
2705        update
2706            .remove_outgoing_requests
2707            .push(responder_id.to_proto());
2708
2709        for connection_id in pool.user_connection_ids(requester_id) {
2710            session.peer.send(connection_id, update.clone())?;
2711        }
2712
2713        send_notifications(&pool, &session.peer, notifications);
2714    }
2715
2716    response.send(proto::Ack {})?;
2717    Ok(())
2718}
2719
2720/// Remove a contact.
2721async fn remove_contact(
2722    request: proto::RemoveContact,
2723    response: Response<proto::RemoveContact>,
2724    session: Session,
2725) -> Result<()> {
2726    let requester_id = session.user_id();
2727    let responder_id = UserId::from_proto(request.user_id);
2728    let db = session.db().await;
2729    let (contact_accepted, deleted_notification_id) =
2730        db.remove_contact(requester_id, responder_id).await?;
2731
2732    let pool = session.connection_pool().await;
2733    // Update outgoing contact requests of requester
2734    let mut update = proto::UpdateContacts::default();
2735    if contact_accepted {
2736        update.remove_contacts.push(responder_id.to_proto());
2737    } else {
2738        update
2739            .remove_outgoing_requests
2740            .push(responder_id.to_proto());
2741    }
2742    for connection_id in pool.user_connection_ids(requester_id) {
2743        session.peer.send(connection_id, update.clone())?;
2744    }
2745
2746    // Update incoming contact requests of responder
2747    let mut update = proto::UpdateContacts::default();
2748    if contact_accepted {
2749        update.remove_contacts.push(requester_id.to_proto());
2750    } else {
2751        update
2752            .remove_incoming_requests
2753            .push(requester_id.to_proto());
2754    }
2755    for connection_id in pool.user_connection_ids(responder_id) {
2756        session.peer.send(connection_id, update.clone())?;
2757        if let Some(notification_id) = deleted_notification_id {
2758            session.peer.send(
2759                connection_id,
2760                proto::DeleteNotification {
2761                    notification_id: notification_id.to_proto(),
2762                },
2763            )?;
2764        }
2765    }
2766
2767    response.send(proto::Ack {})?;
2768    Ok(())
2769}
2770
2771fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2772    version.0.minor() < 139
2773}
2774
2775async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2776    if is_staff {
2777        return Ok(proto::Plan::ZedPro);
2778    }
2779
2780    let subscription = db.get_active_billing_subscription(user_id).await?;
2781    let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2782
2783    let plan = if let Some(subscription_kind) = subscription_kind {
2784        match subscription_kind {
2785            SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2786            SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2787            SubscriptionKind::ZedFree => proto::Plan::Free,
2788        }
2789    } else {
2790        proto::Plan::Free
2791    };
2792
2793    Ok(plan)
2794}
2795
2796async fn make_update_user_plan_message(
2797    user: &User,
2798    is_staff: bool,
2799    db: &Arc<Database>,
2800    llm_db: Option<Arc<LlmDatabase>>,
2801) -> Result<proto::UpdateUserPlan> {
2802    let feature_flags = db.get_user_flags(user.id).await?;
2803    let plan = current_plan(db, user.id, is_staff).await?;
2804    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
2805    let billing_preferences = db.get_billing_preferences(user.id).await?;
2806
2807    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
2808        let subscription = db.get_active_billing_subscription(user.id).await?;
2809
2810        let subscription_period =
2811            crate::db::billing_subscription::Model::current_period(subscription, is_staff);
2812
2813        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
2814            llm_db
2815                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
2816                .await?
2817        } else {
2818            None
2819        };
2820
2821        (subscription_period, usage)
2822    } else {
2823        (None, None)
2824    };
2825
2826    let bypass_account_age_check = feature_flags
2827        .iter()
2828        .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
2829    let account_too_young = !matches!(plan, proto::Plan::ZedPro)
2830        && !bypass_account_age_check
2831        && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;
2832
2833    Ok(proto::UpdateUserPlan {
2834        plan: plan.into(),
2835        trial_started_at: billing_customer
2836            .as_ref()
2837            .and_then(|billing_customer| billing_customer.trial_started_at)
2838            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2839        is_usage_based_billing_enabled: if is_staff {
2840            Some(true)
2841        } else {
2842            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
2843        },
2844        subscription_period: subscription_period.map(|(started_at, ended_at)| {
2845            proto::SubscriptionPeriod {
2846                started_at: started_at.timestamp() as u64,
2847                ended_at: ended_at.timestamp() as u64,
2848            }
2849        }),
2850        account_too_young: Some(account_too_young),
2851        has_overdue_invoices: billing_customer
2852            .map(|billing_customer| billing_customer.has_overdue_invoices),
2853        usage: Some(
2854            usage
2855                .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
2856                .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
2857        ),
2858    })
2859}
2860
2861fn model_requests_limit(
2862    plan: zed_llm_client::Plan,
2863    feature_flags: &Vec<String>,
2864) -> zed_llm_client::UsageLimit {
2865    match plan.model_requests_limit() {
2866        zed_llm_client::UsageLimit::Limited(limit) => {
2867            let limit = if plan == zed_llm_client::Plan::ZedProTrial
2868                && feature_flags
2869                    .iter()
2870                    .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2871            {
2872                1_000
2873            } else {
2874                limit
2875            };
2876
2877            zed_llm_client::UsageLimit::Limited(limit)
2878        }
2879        zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
2880    }
2881}
2882
2883fn subscription_usage_to_proto(
2884    plan: proto::Plan,
2885    usage: crate::llm::db::subscription_usage::Model,
2886    feature_flags: &Vec<String>,
2887) -> proto::SubscriptionUsage {
2888    let plan = match plan {
2889        proto::Plan::Free => zed_llm_client::Plan::ZedFree,
2890        proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
2891        proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
2892    };
2893
2894    proto::SubscriptionUsage {
2895        model_requests_usage_amount: usage.model_requests as u32,
2896        model_requests_usage_limit: Some(proto::UsageLimit {
2897            variant: Some(match model_requests_limit(plan, feature_flags) {
2898                zed_llm_client::UsageLimit::Limited(limit) => {
2899                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2900                        limit: limit as u32,
2901                    })
2902                }
2903                zed_llm_client::UsageLimit::Unlimited => {
2904                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2905                }
2906            }),
2907        }),
2908        edit_predictions_usage_amount: usage.edit_predictions as u32,
2909        edit_predictions_usage_limit: Some(proto::UsageLimit {
2910            variant: Some(match plan.edit_predictions_limit() {
2911                zed_llm_client::UsageLimit::Limited(limit) => {
2912                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2913                        limit: limit as u32,
2914                    })
2915                }
2916                zed_llm_client::UsageLimit::Unlimited => {
2917                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2918                }
2919            }),
2920        }),
2921    }
2922}
2923
2924fn make_default_subscription_usage(
2925    plan: proto::Plan,
2926    feature_flags: &Vec<String>,
2927) -> proto::SubscriptionUsage {
2928    let plan = match plan {
2929        proto::Plan::Free => zed_llm_client::Plan::ZedFree,
2930        proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
2931        proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
2932    };
2933
2934    proto::SubscriptionUsage {
2935        model_requests_usage_amount: 0,
2936        model_requests_usage_limit: Some(proto::UsageLimit {
2937            variant: Some(match model_requests_limit(plan, feature_flags) {
2938                zed_llm_client::UsageLimit::Limited(limit) => {
2939                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2940                        limit: limit as u32,
2941                    })
2942                }
2943                zed_llm_client::UsageLimit::Unlimited => {
2944                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2945                }
2946            }),
2947        }),
2948        edit_predictions_usage_amount: 0,
2949        edit_predictions_usage_limit: Some(proto::UsageLimit {
2950            variant: Some(match plan.edit_predictions_limit() {
2951                zed_llm_client::UsageLimit::Limited(limit) => {
2952                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2953                        limit: limit as u32,
2954                    })
2955                }
2956                zed_llm_client::UsageLimit::Unlimited => {
2957                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2958                }
2959            }),
2960        }),
2961    }
2962}
2963
2964async fn update_user_plan(session: &Session) -> Result<()> {
2965    let db = session.db().await;
2966
2967    let update_user_plan = make_update_user_plan_message(
2968        session.principal.user(),
2969        session.is_staff(),
2970        &db.0,
2971        session.app_state.llm_db.clone(),
2972    )
2973    .await?;
2974
2975    session
2976        .peer
2977        .send(session.connection_id, update_user_plan)
2978        .trace_err();
2979
2980    Ok(())
2981}
2982
2983async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2984    subscribe_user_to_channels(session.user_id(), &session).await?;
2985    Ok(())
2986}
2987
2988async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2989    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2990    let mut pool = session.connection_pool().await;
2991    for membership in &channels_for_user.channel_memberships {
2992        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2993    }
2994    session.peer.send(
2995        session.connection_id,
2996        build_update_user_channels(&channels_for_user),
2997    )?;
2998    session.peer.send(
2999        session.connection_id,
3000        build_channels_update(channels_for_user),
3001    )?;
3002    Ok(())
3003}
3004
3005/// Creates a new channel.
3006async fn create_channel(
3007    request: proto::CreateChannel,
3008    response: Response<proto::CreateChannel>,
3009    session: Session,
3010) -> Result<()> {
3011    let db = session.db().await;
3012
3013    let parent_id = request.parent_id.map(ChannelId::from_proto);
3014    let (channel, membership) = db
3015        .create_channel(&request.name, parent_id, session.user_id())
3016        .await?;
3017
3018    let root_id = channel.root_id();
3019    let channel = Channel::from_model(channel);
3020
3021    response.send(proto::CreateChannelResponse {
3022        channel: Some(channel.to_proto()),
3023        parent_id: request.parent_id,
3024    })?;
3025
3026    let mut connection_pool = session.connection_pool().await;
3027    if let Some(membership) = membership {
3028        connection_pool.subscribe_to_channel(
3029            membership.user_id,
3030            membership.channel_id,
3031            membership.role,
3032        );
3033        let update = proto::UpdateUserChannels {
3034            channel_memberships: vec![proto::ChannelMembership {
3035                channel_id: membership.channel_id.to_proto(),
3036                role: membership.role.into(),
3037            }],
3038            ..Default::default()
3039        };
3040        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3041            session.peer.send(connection_id, update.clone())?;
3042        }
3043    }
3044
3045    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3046        if !role.can_see_channel(channel.visibility) {
3047            continue;
3048        }
3049
3050        let update = proto::UpdateChannels {
3051            channels: vec![channel.to_proto()],
3052            ..Default::default()
3053        };
3054        session.peer.send(connection_id, update.clone())?;
3055    }
3056
3057    Ok(())
3058}
3059
3060/// Delete a channel
3061async fn delete_channel(
3062    request: proto::DeleteChannel,
3063    response: Response<proto::DeleteChannel>,
3064    session: Session,
3065) -> Result<()> {
3066    let db = session.db().await;
3067
3068    let channel_id = request.channel_id;
3069    let (root_channel, removed_channels) = db
3070        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3071        .await?;
3072    response.send(proto::Ack {})?;
3073
3074    // Notify members of removed channels
3075    let mut update = proto::UpdateChannels::default();
3076    update
3077        .delete_channels
3078        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3079
3080    let connection_pool = session.connection_pool().await;
3081    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3082        session.peer.send(connection_id, update.clone())?;
3083    }
3084
3085    Ok(())
3086}
3087
3088/// Invite someone to join a channel.
3089async fn invite_channel_member(
3090    request: proto::InviteChannelMember,
3091    response: Response<proto::InviteChannelMember>,
3092    session: Session,
3093) -> Result<()> {
3094    let db = session.db().await;
3095    let channel_id = ChannelId::from_proto(request.channel_id);
3096    let invitee_id = UserId::from_proto(request.user_id);
3097    let InviteMemberResult {
3098        channel,
3099        notifications,
3100    } = db
3101        .invite_channel_member(
3102            channel_id,
3103            invitee_id,
3104            session.user_id(),
3105            request.role().into(),
3106        )
3107        .await?;
3108
3109    let update = proto::UpdateChannels {
3110        channel_invitations: vec![channel.to_proto()],
3111        ..Default::default()
3112    };
3113
3114    let connection_pool = session.connection_pool().await;
3115    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3116        session.peer.send(connection_id, update.clone())?;
3117    }
3118
3119    send_notifications(&connection_pool, &session.peer, notifications);
3120
3121    response.send(proto::Ack {})?;
3122    Ok(())
3123}
3124
3125/// remove someone from a channel
3126async fn remove_channel_member(
3127    request: proto::RemoveChannelMember,
3128    response: Response<proto::RemoveChannelMember>,
3129    session: Session,
3130) -> Result<()> {
3131    let db = session.db().await;
3132    let channel_id = ChannelId::from_proto(request.channel_id);
3133    let member_id = UserId::from_proto(request.user_id);
3134
3135    let RemoveChannelMemberResult {
3136        membership_update,
3137        notification_id,
3138    } = db
3139        .remove_channel_member(channel_id, member_id, session.user_id())
3140        .await?;
3141
3142    let mut connection_pool = session.connection_pool().await;
3143    notify_membership_updated(
3144        &mut connection_pool,
3145        membership_update,
3146        member_id,
3147        &session.peer,
3148    );
3149    for connection_id in connection_pool.user_connection_ids(member_id) {
3150        if let Some(notification_id) = notification_id {
3151            session
3152                .peer
3153                .send(
3154                    connection_id,
3155                    proto::DeleteNotification {
3156                        notification_id: notification_id.to_proto(),
3157                    },
3158                )
3159                .trace_err();
3160        }
3161    }
3162
3163    response.send(proto::Ack {})?;
3164    Ok(())
3165}
3166
3167/// Toggle the channel between public and private.
3168/// Care is taken to maintain the invariant that public channels only descend from public channels,
3169/// (though members-only channels can appear at any point in the hierarchy).
3170async fn set_channel_visibility(
3171    request: proto::SetChannelVisibility,
3172    response: Response<proto::SetChannelVisibility>,
3173    session: Session,
3174) -> Result<()> {
3175    let db = session.db().await;
3176    let channel_id = ChannelId::from_proto(request.channel_id);
3177    let visibility = request.visibility().into();
3178
3179    let channel_model = db
3180        .set_channel_visibility(channel_id, visibility, session.user_id())
3181        .await?;
3182    let root_id = channel_model.root_id();
3183    let channel = Channel::from_model(channel_model);
3184
3185    let mut connection_pool = session.connection_pool().await;
3186    for (user_id, role) in connection_pool
3187        .channel_user_ids(root_id)
3188        .collect::<Vec<_>>()
3189        .into_iter()
3190    {
3191        let update = if role.can_see_channel(channel.visibility) {
3192            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3193            proto::UpdateChannels {
3194                channels: vec![channel.to_proto()],
3195                ..Default::default()
3196            }
3197        } else {
3198            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3199            proto::UpdateChannels {
3200                delete_channels: vec![channel.id.to_proto()],
3201                ..Default::default()
3202            }
3203        };
3204
3205        for connection_id in connection_pool.user_connection_ids(user_id) {
3206            session.peer.send(connection_id, update.clone())?;
3207        }
3208    }
3209
3210    response.send(proto::Ack {})?;
3211    Ok(())
3212}
3213
3214/// Alter the role for a user in the channel.
3215async fn set_channel_member_role(
3216    request: proto::SetChannelMemberRole,
3217    response: Response<proto::SetChannelMemberRole>,
3218    session: Session,
3219) -> Result<()> {
3220    let db = session.db().await;
3221    let channel_id = ChannelId::from_proto(request.channel_id);
3222    let member_id = UserId::from_proto(request.user_id);
3223    let result = db
3224        .set_channel_member_role(
3225            channel_id,
3226            session.user_id(),
3227            member_id,
3228            request.role().into(),
3229        )
3230        .await?;
3231
3232    match result {
3233        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3234            let mut connection_pool = session.connection_pool().await;
3235            notify_membership_updated(
3236                &mut connection_pool,
3237                membership_update,
3238                member_id,
3239                &session.peer,
3240            )
3241        }
3242        db::SetMemberRoleResult::InviteUpdated(channel) => {
3243            let update = proto::UpdateChannels {
3244                channel_invitations: vec![channel.to_proto()],
3245                ..Default::default()
3246            };
3247
3248            for connection_id in session
3249                .connection_pool()
3250                .await
3251                .user_connection_ids(member_id)
3252            {
3253                session.peer.send(connection_id, update.clone())?;
3254            }
3255        }
3256    }
3257
3258    response.send(proto::Ack {})?;
3259    Ok(())
3260}
3261
3262/// Change the name of a channel
3263async fn rename_channel(
3264    request: proto::RenameChannel,
3265    response: Response<proto::RenameChannel>,
3266    session: Session,
3267) -> Result<()> {
3268    let db = session.db().await;
3269    let channel_id = ChannelId::from_proto(request.channel_id);
3270    let channel_model = db
3271        .rename_channel(channel_id, session.user_id(), &request.name)
3272        .await?;
3273    let root_id = channel_model.root_id();
3274    let channel = Channel::from_model(channel_model);
3275
3276    response.send(proto::RenameChannelResponse {
3277        channel: Some(channel.to_proto()),
3278    })?;
3279
3280    let connection_pool = session.connection_pool().await;
3281    let update = proto::UpdateChannels {
3282        channels: vec![channel.to_proto()],
3283        ..Default::default()
3284    };
3285    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3286        if role.can_see_channel(channel.visibility) {
3287            session.peer.send(connection_id, update.clone())?;
3288        }
3289    }
3290
3291    Ok(())
3292}
3293
3294/// Move a channel to a new parent.
3295async fn move_channel(
3296    request: proto::MoveChannel,
3297    response: Response<proto::MoveChannel>,
3298    session: Session,
3299) -> Result<()> {
3300    let channel_id = ChannelId::from_proto(request.channel_id);
3301    let to = ChannelId::from_proto(request.to);
3302
3303    let (root_id, channels) = session
3304        .db()
3305        .await
3306        .move_channel(channel_id, to, session.user_id())
3307        .await?;
3308
3309    let connection_pool = session.connection_pool().await;
3310    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3311        let channels = channels
3312            .iter()
3313            .filter_map(|channel| {
3314                if role.can_see_channel(channel.visibility) {
3315                    Some(channel.to_proto())
3316                } else {
3317                    None
3318                }
3319            })
3320            .collect::<Vec<_>>();
3321        if channels.is_empty() {
3322            continue;
3323        }
3324
3325        let update = proto::UpdateChannels {
3326            channels,
3327            ..Default::default()
3328        };
3329
3330        session.peer.send(connection_id, update.clone())?;
3331    }
3332
3333    response.send(Ack {})?;
3334    Ok(())
3335}
3336
3337async fn reorder_channel(
3338    request: proto::ReorderChannel,
3339    response: Response<proto::ReorderChannel>,
3340    session: Session,
3341) -> Result<()> {
3342    let channel_id = ChannelId::from_proto(request.channel_id);
3343    let direction = request.direction();
3344
3345    let updated_channels = session
3346        .db()
3347        .await
3348        .reorder_channel(channel_id, direction, session.user_id())
3349        .await?;
3350
3351    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3352        let connection_pool = session.connection_pool().await;
3353        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3354            let channels = updated_channels
3355                .iter()
3356                .filter_map(|channel| {
3357                    if role.can_see_channel(channel.visibility) {
3358                        Some(channel.to_proto())
3359                    } else {
3360                        None
3361                    }
3362                })
3363                .collect::<Vec<_>>();
3364
3365            if channels.is_empty() {
3366                continue;
3367            }
3368
3369            let update = proto::UpdateChannels {
3370                channels,
3371                ..Default::default()
3372            };
3373
3374            session.peer.send(connection_id, update.clone())?;
3375        }
3376    }
3377
3378    response.send(Ack {})?;
3379    Ok(())
3380}
3381
3382/// Get the list of channel members
3383async fn get_channel_members(
3384    request: proto::GetChannelMembers,
3385    response: Response<proto::GetChannelMembers>,
3386    session: Session,
3387) -> Result<()> {
3388    let db = session.db().await;
3389    let channel_id = ChannelId::from_proto(request.channel_id);
3390    let limit = if request.limit == 0 {
3391        u16::MAX as u64
3392    } else {
3393        request.limit
3394    };
3395    let (members, users) = db
3396        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3397        .await?;
3398    response.send(proto::GetChannelMembersResponse { members, users })?;
3399    Ok(())
3400}
3401
3402/// Accept or decline a channel invitation.
3403async fn respond_to_channel_invite(
3404    request: proto::RespondToChannelInvite,
3405    response: Response<proto::RespondToChannelInvite>,
3406    session: Session,
3407) -> Result<()> {
3408    let db = session.db().await;
3409    let channel_id = ChannelId::from_proto(request.channel_id);
3410    let RespondToChannelInvite {
3411        membership_update,
3412        notifications,
3413    } = db
3414        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3415        .await?;
3416
3417    let mut connection_pool = session.connection_pool().await;
3418    if let Some(membership_update) = membership_update {
3419        notify_membership_updated(
3420            &mut connection_pool,
3421            membership_update,
3422            session.user_id(),
3423            &session.peer,
3424        );
3425    } else {
3426        let update = proto::UpdateChannels {
3427            remove_channel_invitations: vec![channel_id.to_proto()],
3428            ..Default::default()
3429        };
3430
3431        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3432            session.peer.send(connection_id, update.clone())?;
3433        }
3434    };
3435
3436    send_notifications(&connection_pool, &session.peer, notifications);
3437
3438    response.send(proto::Ack {})?;
3439
3440    Ok(())
3441}
3442
3443/// Join the channels' room
3444async fn join_channel(
3445    request: proto::JoinChannel,
3446    response: Response<proto::JoinChannel>,
3447    session: Session,
3448) -> Result<()> {
3449    let channel_id = ChannelId::from_proto(request.channel_id);
3450    join_channel_internal(channel_id, Box::new(response), session).await
3451}
3452
/// Response handle accepted by `join_channel_internal`.
///
/// Implemented below for the responses to both `proto::JoinChannel` and
/// `proto::JoinRoom`, each of which is answered with a
/// `proto::JoinRoomResponse`.
trait JoinChannelInternalResponse {
    /// Consumes the handle and sends `result` through it.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // NOTE(review): the fully qualified path presumably selects `Response`'s
        // own (inherent) `send` rather than this trait method — confirm before
        // simplifying, since resolving to the trait method would recurse.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same delegation as above, for the `JoinRoom` response type.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3466
/// Shared implementation for joining a channel's room.
///
/// Steps, in order:
/// 1. If the user has a stale room connection, leave that room first.
/// 2. Join the channel in the database.
/// 3. If a LiveKit client is configured, try to mint a token for the room.
/// 4. Send the `JoinRoomResponse` through `response`.
/// 5. Broadcast any membership change, the room update, and the channel update.
/// 6. Refresh the user's contacts.
///
/// Returns an error if any database call or the response send fails, or if the
/// joined room carries no channel ("channel not returned").
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The inner block scopes the `db` and `connection_pool` guards: both are
    // dropped when it ends, before the final `channel_updated` /
    // `update_user_contacts` calls below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // The db handle is released for the duration of the leave and
            // re-acquired afterwards.
            // NOTE(review): presumably `leave_room_for_session` acquires the
            // db itself — confirm before reordering.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when `trace_err()`
        // turns a token-creation failure into `None` (the `?` below exits the
        // closure, not this function).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests get a guest token and may not publish; every
                    // other role gets a regular room token and may publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the requester before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The connection pool guard taken here is a temporary that lives only for
    // this statement, so it is released before `update_user_contacts` runs.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3560
3561/// Start editing the channel notes
3562async fn join_channel_buffer(
3563    request: proto::JoinChannelBuffer,
3564    response: Response<proto::JoinChannelBuffer>,
3565    session: Session,
3566) -> Result<()> {
3567    let db = session.db().await;
3568    let channel_id = ChannelId::from_proto(request.channel_id);
3569
3570    let open_response = db
3571        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3572        .await?;
3573
3574    let collaborators = open_response.collaborators.clone();
3575    response.send(open_response)?;
3576
3577    let update = UpdateChannelBufferCollaborators {
3578        channel_id: channel_id.to_proto(),
3579        collaborators: collaborators.clone(),
3580    };
3581    channel_buffer_updated(
3582        session.connection_id,
3583        collaborators
3584            .iter()
3585            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3586        &update,
3587        &session.peer,
3588    );
3589
3590    Ok(())
3591}
3592
3593/// Edit the channel notes
3594async fn update_channel_buffer(
3595    request: proto::UpdateChannelBuffer,
3596    session: Session,
3597) -> Result<()> {
3598    let db = session.db().await;
3599    let channel_id = ChannelId::from_proto(request.channel_id);
3600
3601    let (collaborators, epoch, version) = db
3602        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3603        .await?;
3604
3605    channel_buffer_updated(
3606        session.connection_id,
3607        collaborators.clone(),
3608        &proto::UpdateChannelBuffer {
3609            channel_id: channel_id.to_proto(),
3610            operations: request.operations,
3611        },
3612        &session.peer,
3613    );
3614
3615    let pool = &*session.connection_pool().await;
3616
3617    let non_collaborators =
3618        pool.channel_connection_ids(channel_id)
3619            .filter_map(|(connection_id, _)| {
3620                if collaborators.contains(&connection_id) {
3621                    None
3622                } else {
3623                    Some(connection_id)
3624                }
3625            });
3626
3627    broadcast(None, non_collaborators, |peer_id| {
3628        session.peer.send(
3629            peer_id,
3630            proto::UpdateChannels {
3631                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3632                    channel_id: channel_id.to_proto(),
3633                    epoch: epoch as u64,
3634                    version: version.clone(),
3635                }],
3636                ..Default::default()
3637            },
3638        )
3639    });
3640
3641    Ok(())
3642}
3643
3644/// Rejoin the channel notes after a connection blip
3645async fn rejoin_channel_buffers(
3646    request: proto::RejoinChannelBuffers,
3647    response: Response<proto::RejoinChannelBuffers>,
3648    session: Session,
3649) -> Result<()> {
3650    let db = session.db().await;
3651    let buffers = db
3652        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3653        .await?;
3654
3655    for rejoined_buffer in &buffers {
3656        let collaborators_to_notify = rejoined_buffer
3657            .buffer
3658            .collaborators
3659            .iter()
3660            .filter_map(|c| Some(c.peer_id?.into()));
3661        channel_buffer_updated(
3662            session.connection_id,
3663            collaborators_to_notify,
3664            &proto::UpdateChannelBufferCollaborators {
3665                channel_id: rejoined_buffer.buffer.channel_id,
3666                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3667            },
3668            &session.peer,
3669        );
3670    }
3671
3672    response.send(proto::RejoinChannelBuffersResponse {
3673        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3674    })?;
3675
3676    Ok(())
3677}
3678
3679/// Stop editing the channel notes
3680async fn leave_channel_buffer(
3681    request: proto::LeaveChannelBuffer,
3682    response: Response<proto::LeaveChannelBuffer>,
3683    session: Session,
3684) -> Result<()> {
3685    let db = session.db().await;
3686    let channel_id = ChannelId::from_proto(request.channel_id);
3687
3688    let left_buffer = db
3689        .leave_channel_buffer(channel_id, session.connection_id)
3690        .await?;
3691
3692    response.send(Ack {})?;
3693
3694    channel_buffer_updated(
3695        session.connection_id,
3696        left_buffer.connections,
3697        &proto::UpdateChannelBufferCollaborators {
3698            channel_id: channel_id.to_proto(),
3699            collaborators: left_buffer.collaborators,
3700        },
3701        &session.peer,
3702    );
3703
3704    Ok(())
3705}
3706
3707fn channel_buffer_updated<T: EnvelopedMessage>(
3708    sender_id: ConnectionId,
3709    collaborators: impl IntoIterator<Item = ConnectionId>,
3710    message: &T,
3711    peer: &Peer,
3712) {
3713    broadcast(Some(sender_id), collaborators, |peer_id| {
3714        peer.send(peer_id, message.clone())
3715    });
3716}
3717
3718fn send_notifications(
3719    connection_pool: &ConnectionPool,
3720    peer: &Peer,
3721    notifications: db::NotificationBatch,
3722) {
3723    for (user_id, notification) in notifications {
3724        for connection_id in connection_pool.user_connection_ids(user_id) {
3725            if let Err(error) = peer.send(
3726                connection_id,
3727                proto::AddNotification {
3728                    notification: Some(notification.clone()),
3729                },
3730            ) {
3731                tracing::error!(
3732                    "failed to send notification to {:?} {}",
3733                    connection_id,
3734                    error
3735                );
3736            }
3737        }
3738    }
3739}
3740
3741/// Send a message to the channel
3742async fn send_channel_message(
3743    request: proto::SendChannelMessage,
3744    response: Response<proto::SendChannelMessage>,
3745    session: Session,
3746) -> Result<()> {
3747    // Validate the message body.
3748    let body = request.body.trim().to_string();
3749    if body.len() > MAX_MESSAGE_LEN {
3750        return Err(anyhow!("message is too long"))?;
3751    }
3752    if body.is_empty() {
3753        return Err(anyhow!("message can't be blank"))?;
3754    }
3755
3756    // TODO: adjust mentions if body is trimmed
3757
3758    let timestamp = OffsetDateTime::now_utc();
3759    let nonce = request.nonce.context("nonce can't be blank")?;
3760
3761    let channel_id = ChannelId::from_proto(request.channel_id);
3762    let CreatedChannelMessage {
3763        message_id,
3764        participant_connection_ids,
3765        notifications,
3766    } = session
3767        .db()
3768        .await
3769        .create_channel_message(
3770            channel_id,
3771            session.user_id(),
3772            &body,
3773            &request.mentions,
3774            timestamp,
3775            nonce.clone().into(),
3776            request.reply_to_message_id.map(MessageId::from_proto),
3777        )
3778        .await?;
3779
3780    let message = proto::ChannelMessage {
3781        sender_id: session.user_id().to_proto(),
3782        id: message_id.to_proto(),
3783        body,
3784        mentions: request.mentions,
3785        timestamp: timestamp.unix_timestamp() as u64,
3786        nonce: Some(nonce),
3787        reply_to_message_id: request.reply_to_message_id,
3788        edited_at: None,
3789    };
3790    broadcast(
3791        Some(session.connection_id),
3792        participant_connection_ids.clone(),
3793        |connection| {
3794            session.peer.send(
3795                connection,
3796                proto::ChannelMessageSent {
3797                    channel_id: channel_id.to_proto(),
3798                    message: Some(message.clone()),
3799                },
3800            )
3801        },
3802    );
3803    response.send(proto::SendChannelMessageResponse {
3804        message: Some(message),
3805    })?;
3806
3807    let pool = &*session.connection_pool().await;
3808    let non_participants =
3809        pool.channel_connection_ids(channel_id)
3810            .filter_map(|(connection_id, _)| {
3811                if participant_connection_ids.contains(&connection_id) {
3812                    None
3813                } else {
3814                    Some(connection_id)
3815                }
3816            });
3817    broadcast(None, non_participants, |peer_id| {
3818        session.peer.send(
3819            peer_id,
3820            proto::UpdateChannels {
3821                latest_channel_message_ids: vec![proto::ChannelMessageId {
3822                    channel_id: channel_id.to_proto(),
3823                    message_id: message_id.to_proto(),
3824                }],
3825                ..Default::default()
3826            },
3827        )
3828    });
3829    send_notifications(pool, &session.peer, notifications);
3830
3831    Ok(())
3832}
3833
3834/// Delete a channel message
3835async fn remove_channel_message(
3836    request: proto::RemoveChannelMessage,
3837    response: Response<proto::RemoveChannelMessage>,
3838    session: Session,
3839) -> Result<()> {
3840    let channel_id = ChannelId::from_proto(request.channel_id);
3841    let message_id = MessageId::from_proto(request.message_id);
3842    let (connection_ids, existing_notification_ids) = session
3843        .db()
3844        .await
3845        .remove_channel_message(channel_id, message_id, session.user_id())
3846        .await?;
3847
3848    broadcast(
3849        Some(session.connection_id),
3850        connection_ids,
3851        move |connection| {
3852            session.peer.send(connection, request.clone())?;
3853
3854            for notification_id in &existing_notification_ids {
3855                session.peer.send(
3856                    connection,
3857                    proto::DeleteNotification {
3858                        notification_id: (*notification_id).to_proto(),
3859                    },
3860                )?;
3861            }
3862
3863            Ok(())
3864        },
3865    );
3866    response.send(proto::Ack {})?;
3867    Ok(())
3868}
3869
3870async fn update_channel_message(
3871    request: proto::UpdateChannelMessage,
3872    response: Response<proto::UpdateChannelMessage>,
3873    session: Session,
3874) -> Result<()> {
3875    let channel_id = ChannelId::from_proto(request.channel_id);
3876    let message_id = MessageId::from_proto(request.message_id);
3877    let updated_at = OffsetDateTime::now_utc();
3878    let UpdatedChannelMessage {
3879        message_id,
3880        participant_connection_ids,
3881        notifications,
3882        reply_to_message_id,
3883        timestamp,
3884        deleted_mention_notification_ids,
3885        updated_mention_notifications,
3886    } = session
3887        .db()
3888        .await
3889        .update_channel_message(
3890            channel_id,
3891            message_id,
3892            session.user_id(),
3893            request.body.as_str(),
3894            &request.mentions,
3895            updated_at,
3896        )
3897        .await?;
3898
3899    let nonce = request.nonce.clone().context("nonce can't be blank")?;
3900
3901    let message = proto::ChannelMessage {
3902        sender_id: session.user_id().to_proto(),
3903        id: message_id.to_proto(),
3904        body: request.body.clone(),
3905        mentions: request.mentions.clone(),
3906        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3907        nonce: Some(nonce),
3908        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3909        edited_at: Some(updated_at.unix_timestamp() as u64),
3910    };
3911
3912    response.send(proto::Ack {})?;
3913
3914    let pool = &*session.connection_pool().await;
3915    broadcast(
3916        Some(session.connection_id),
3917        participant_connection_ids,
3918        |connection| {
3919            session.peer.send(
3920                connection,
3921                proto::ChannelMessageUpdate {
3922                    channel_id: channel_id.to_proto(),
3923                    message: Some(message.clone()),
3924                },
3925            )?;
3926
3927            for notification_id in &deleted_mention_notification_ids {
3928                session.peer.send(
3929                    connection,
3930                    proto::DeleteNotification {
3931                        notification_id: (*notification_id).to_proto(),
3932                    },
3933                )?;
3934            }
3935
3936            for notification in &updated_mention_notifications {
3937                session.peer.send(
3938                    connection,
3939                    proto::UpdateNotification {
3940                        notification: Some(notification.clone()),
3941                    },
3942                )?;
3943            }
3944
3945            Ok(())
3946        },
3947    );
3948
3949    send_notifications(pool, &session.peer, notifications);
3950
3951    Ok(())
3952}
3953
3954/// Mark a channel message as read
3955async fn acknowledge_channel_message(
3956    request: proto::AckChannelMessage,
3957    session: Session,
3958) -> Result<()> {
3959    let channel_id = ChannelId::from_proto(request.channel_id);
3960    let message_id = MessageId::from_proto(request.message_id);
3961    let notifications = session
3962        .db()
3963        .await
3964        .observe_channel_message(channel_id, session.user_id(), message_id)
3965        .await?;
3966    send_notifications(
3967        &*session.connection_pool().await,
3968        &session.peer,
3969        notifications,
3970    );
3971    Ok(())
3972}
3973
3974/// Mark a buffer version as synced
3975async fn acknowledge_buffer_version(
3976    request: proto::AckBufferOperation,
3977    session: Session,
3978) -> Result<()> {
3979    let buffer_id = BufferId::from_proto(request.buffer_id);
3980    session
3981        .db()
3982        .await
3983        .observe_buffer_version(
3984            buffer_id,
3985            session.user_id(),
3986            request.epoch as i32,
3987            &request.version,
3988        )
3989        .await?;
3990    Ok(())
3991}
3992
3993/// Get a Supermaven API key for the user
3994async fn get_supermaven_api_key(
3995    _request: proto::GetSupermavenApiKey,
3996    response: Response<proto::GetSupermavenApiKey>,
3997    session: Session,
3998) -> Result<()> {
3999    let user_id: String = session.user_id().to_string();
4000    if !session.is_staff() {
4001        return Err(anyhow!("supermaven not enabled for this account"))?;
4002    }
4003
4004    let email = session.email().context("user must have an email")?;
4005
4006    let supermaven_admin_api = session
4007        .supermaven_client
4008        .as_ref()
4009        .context("supermaven not configured")?;
4010
4011    let result = supermaven_admin_api
4012        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4013        .await?;
4014
4015    response.send(proto::GetSupermavenApiKeyResponse {
4016        api_key: result.api_key,
4017    })?;
4018
4019    Ok(())
4020}
4021
4022/// Start receiving chat updates for a channel
4023async fn join_channel_chat(
4024    request: proto::JoinChannelChat,
4025    response: Response<proto::JoinChannelChat>,
4026    session: Session,
4027) -> Result<()> {
4028    let channel_id = ChannelId::from_proto(request.channel_id);
4029
4030    let db = session.db().await;
4031    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4032        .await?;
4033    let messages = db
4034        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4035        .await?;
4036    response.send(proto::JoinChannelChatResponse {
4037        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4038        messages,
4039    })?;
4040    Ok(())
4041}
4042
4043/// Stop receiving chat updates for a channel
4044async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
4045    let channel_id = ChannelId::from_proto(request.channel_id);
4046    session
4047        .db()
4048        .await
4049        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4050        .await?;
4051    Ok(())
4052}
4053
4054/// Retrieve the chat history for a channel
4055async fn get_channel_messages(
4056    request: proto::GetChannelMessages,
4057    response: Response<proto::GetChannelMessages>,
4058    session: Session,
4059) -> Result<()> {
4060    let channel_id = ChannelId::from_proto(request.channel_id);
4061    let messages = session
4062        .db()
4063        .await
4064        .get_channel_messages(
4065            channel_id,
4066            session.user_id(),
4067            MESSAGE_COUNT_PER_PAGE,
4068            Some(MessageId::from_proto(request.before_message_id)),
4069        )
4070        .await?;
4071    response.send(proto::GetChannelMessagesResponse {
4072        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4073        messages,
4074    })?;
4075    Ok(())
4076}
4077
4078/// Retrieve specific chat messages
4079async fn get_channel_messages_by_id(
4080    request: proto::GetChannelMessagesById,
4081    response: Response<proto::GetChannelMessagesById>,
4082    session: Session,
4083) -> Result<()> {
4084    let message_ids = request
4085        .message_ids
4086        .iter()
4087        .map(|id| MessageId::from_proto(*id))
4088        .collect::<Vec<_>>();
4089    let messages = session
4090        .db()
4091        .await
4092        .get_channel_messages_by_id(session.user_id(), &message_ids)
4093        .await?;
4094    response.send(proto::GetChannelMessagesResponse {
4095        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4096        messages,
4097    })?;
4098    Ok(())
4099}
4100
4101/// Retrieve the current users notifications
4102async fn get_notifications(
4103    request: proto::GetNotifications,
4104    response: Response<proto::GetNotifications>,
4105    session: Session,
4106) -> Result<()> {
4107    let notifications = session
4108        .db()
4109        .await
4110        .get_notifications(
4111            session.user_id(),
4112            NOTIFICATION_COUNT_PER_PAGE,
4113            request.before_id.map(db::NotificationId::from_proto),
4114        )
4115        .await?;
4116    response.send(proto::GetNotificationsResponse {
4117        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4118        notifications,
4119    })?;
4120    Ok(())
4121}
4122
4123/// Mark notifications as read
4124async fn mark_notification_as_read(
4125    request: proto::MarkNotificationRead,
4126    response: Response<proto::MarkNotificationRead>,
4127    session: Session,
4128) -> Result<()> {
4129    let database = &session.db().await;
4130    let notifications = database
4131        .mark_notification_as_read_by_id(
4132            session.user_id(),
4133            NotificationId::from_proto(request.notification_id),
4134        )
4135        .await?;
4136    send_notifications(
4137        &*session.connection_pool().await,
4138        &session.peer,
4139        notifications,
4140    );
4141    response.send(proto::Ack {})?;
4142    Ok(())
4143}
4144
4145/// Get the current users information
4146async fn get_private_user_info(
4147    _request: proto::GetPrivateUserInfo,
4148    response: Response<proto::GetPrivateUserInfo>,
4149    session: Session,
4150) -> Result<()> {
4151    let db = session.db().await;
4152
4153    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4154    let user = db
4155        .get_user_by_id(session.user_id())
4156        .await?
4157        .context("user not found")?;
4158    let flags = db.get_user_flags(session.user_id()).await?;
4159
4160    response.send(proto::GetPrivateUserInfoResponse {
4161        metrics_id,
4162        staff: user.admin,
4163        flags,
4164        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4165    })?;
4166    Ok(())
4167}
4168
4169/// Accept the terms of service (tos) on behalf of the current user
4170async fn accept_terms_of_service(
4171    _request: proto::AcceptTermsOfService,
4172    response: Response<proto::AcceptTermsOfService>,
4173    session: Session,
4174) -> Result<()> {
4175    let db = session.db().await;
4176
4177    let accepted_tos_at = Utc::now();
4178    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4179        .await?;
4180
4181    response.send(proto::AcceptTermsOfServiceResponse {
4182        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4183    })?;
4184
4185    // When the user accepts the terms of service, we want to refresh their LLM
4186    // token to grant access.
4187    session
4188        .peer
4189        .send(session.connection_id, proto::RefreshLlmToken {})?;
4190
4191    Ok(())
4192}
4193
4194async fn get_llm_api_token(
4195    _request: proto::GetLlmToken,
4196    response: Response<proto::GetLlmToken>,
4197    session: Session,
4198) -> Result<()> {
4199    let db = session.db().await;
4200
4201    let flags = db.get_user_flags(session.user_id()).await?;
4202
4203    let user_id = session.user_id();
4204    let user = db
4205        .get_user_by_id(user_id)
4206        .await?
4207        .with_context(|| format!("user {user_id} not found"))?;
4208
4209    if user.accepted_tos_at.is_none() {
4210        Err(anyhow!("terms of service not accepted"))?
4211    }
4212
4213    let stripe_client = session
4214        .app_state
4215        .stripe_client
4216        .as_ref()
4217        .context("failed to retrieve Stripe client")?;
4218
4219    let stripe_billing = session
4220        .app_state
4221        .stripe_billing
4222        .as_ref()
4223        .context("failed to retrieve Stripe billing object")?;
4224
4225    let billing_customer = if let Some(billing_customer) =
4226        db.get_billing_customer_by_user_id(user.id).await?
4227    {
4228        billing_customer
4229    } else {
4230        let customer_id = stripe_billing
4231            .find_or_create_customer_by_email(user.email_address.as_deref())
4232            .await?;
4233
4234        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
4235            .await?
4236            .context("billing customer not found")?
4237    };
4238
4239    let billing_subscription =
4240        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
4241            billing_subscription
4242        } else {
4243            let stripe_customer_id =
4244                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
4245
4246            let stripe_subscription = stripe_billing
4247                .subscribe_to_zed_free(stripe_customer_id)
4248                .await?;
4249
4250            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
4251                billing_customer_id: billing_customer.id,
4252                kind: Some(SubscriptionKind::ZedFree),
4253                stripe_subscription_id: stripe_subscription.id.to_string(),
4254                stripe_subscription_status: stripe_subscription.status.into(),
4255                stripe_cancellation_reason: None,
4256                stripe_current_period_start: Some(stripe_subscription.current_period_start),
4257                stripe_current_period_end: Some(stripe_subscription.current_period_end),
4258            })
4259            .await?
4260        };
4261
4262    let billing_preferences = db.get_billing_preferences(user.id).await?;
4263
4264    let token = LlmTokenClaims::create(
4265        &user,
4266        session.is_staff(),
4267        billing_customer,
4268        billing_preferences,
4269        &flags,
4270        billing_subscription,
4271        session.system_id.clone(),
4272        &session.app_state.config,
4273    )?;
4274    response.send(proto::GetLlmTokenResponse { token })?;
4275    Ok(())
4276}
4277
4278fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4279    let message = match message {
4280        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4281        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4282        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4283        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4284        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4285            code: frame.code.into(),
4286            reason: frame.reason.as_str().to_owned().into(),
4287        })),
4288        // We should never receive a frame while reading the message, according
4289        // to the `tungstenite` maintainers:
4290        //
4291        // > It cannot occur when you read messages from the WebSocket, but it
4292        // > can be used when you want to send the raw frames (e.g. you want to
4293        // > send the frames to the WebSocket without composing the full message first).
4294        // >
4295        // > — https://github.com/snapview/tungstenite-rs/issues/268
4296        TungsteniteMessage::Frame(_) => {
4297            bail!("received an unexpected frame while reading the message")
4298        }
4299    };
4300
4301    Ok(message)
4302}
4303
4304fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4305    match message {
4306        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4307        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4308        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4309        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4310        AxumMessage::Close(frame) => {
4311            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4312                code: frame.code.into(),
4313                reason: frame.reason.as_ref().into(),
4314            }))
4315        }
4316    }
4317}
4318
4319fn notify_membership_updated(
4320    connection_pool: &mut ConnectionPool,
4321    result: MembershipUpdated,
4322    user_id: UserId,
4323    peer: &Peer,
4324) {
4325    for membership in &result.new_channels.channel_memberships {
4326        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4327    }
4328    for channel_id in &result.removed_channels {
4329        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4330    }
4331
4332    let user_channels_update = proto::UpdateUserChannels {
4333        channel_memberships: result
4334            .new_channels
4335            .channel_memberships
4336            .iter()
4337            .map(|cm| proto::ChannelMembership {
4338                channel_id: cm.channel_id.to_proto(),
4339                role: cm.role.into(),
4340            })
4341            .collect(),
4342        ..Default::default()
4343    };
4344
4345    let mut update = build_channels_update(result.new_channels);
4346    update.delete_channels = result
4347        .removed_channels
4348        .into_iter()
4349        .map(|id| id.to_proto())
4350        .collect();
4351    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4352
4353    for connection_id in connection_pool.user_connection_ids(user_id) {
4354        peer.send(connection_id, user_channels_update.clone())
4355            .trace_err();
4356        peer.send(connection_id, update.clone()).trace_err();
4357    }
4358}
4359
4360fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4361    proto::UpdateUserChannels {
4362        channel_memberships: channels
4363            .channel_memberships
4364            .iter()
4365            .map(|m| proto::ChannelMembership {
4366                channel_id: m.channel_id.to_proto(),
4367                role: m.role.into(),
4368            })
4369            .collect(),
4370        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4371        observed_channel_message_id: channels.observed_channel_messages.clone(),
4372    }
4373}
4374
4375fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4376    let mut update = proto::UpdateChannels::default();
4377
4378    for channel in channels.channels {
4379        update.channels.push(channel.to_proto());
4380    }
4381
4382    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4383    update.latest_channel_message_ids = channels.latest_channel_messages;
4384
4385    for (channel_id, participants) in channels.channel_participants {
4386        update
4387            .channel_participants
4388            .push(proto::ChannelParticipants {
4389                channel_id: channel_id.to_proto(),
4390                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4391            });
4392    }
4393
4394    for channel in channels.invited_channels {
4395        update.channel_invitations.push(channel.to_proto());
4396    }
4397
4398    update
4399}
4400
4401fn build_initial_contacts_update(
4402    contacts: Vec<db::Contact>,
4403    pool: &ConnectionPool,
4404) -> proto::UpdateContacts {
4405    let mut update = proto::UpdateContacts::default();
4406
4407    for contact in contacts {
4408        match contact {
4409            db::Contact::Accepted { user_id, busy } => {
4410                update.contacts.push(contact_for_user(user_id, busy, pool));
4411            }
4412            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4413            db::Contact::Incoming { user_id } => {
4414                update
4415                    .incoming_requests
4416                    .push(proto::IncomingContactRequest {
4417                        requester_id: user_id.to_proto(),
4418                    })
4419            }
4420        }
4421    }
4422
4423    update
4424}
4425
4426fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4427    proto::Contact {
4428        user_id: user_id.to_proto(),
4429        online: pool.is_user_online(user_id),
4430        busy,
4431    }
4432}
4433
4434fn room_updated(room: &proto::Room, peer: &Peer) {
4435    broadcast(
4436        None,
4437        room.participants
4438            .iter()
4439            .filter_map(|participant| Some(participant.peer_id?.into())),
4440        |peer_id| {
4441            peer.send(
4442                peer_id,
4443                proto::RoomUpdated {
4444                    room: Some(room.clone()),
4445                },
4446            )
4447        },
4448    );
4449}
4450
4451fn channel_updated(
4452    channel: &db::channel::Model,
4453    room: &proto::Room,
4454    peer: &Peer,
4455    pool: &ConnectionPool,
4456) {
4457    let participants = room
4458        .participants
4459        .iter()
4460        .map(|p| p.user_id)
4461        .collect::<Vec<_>>();
4462
4463    broadcast(
4464        None,
4465        pool.channel_connection_ids(channel.root_id())
4466            .filter_map(|(channel_id, role)| {
4467                role.can_see_channel(channel.visibility)
4468                    .then_some(channel_id)
4469            }),
4470        |peer_id| {
4471            peer.send(
4472                peer_id,
4473                proto::UpdateChannels {
4474                    channel_participants: vec![proto::ChannelParticipants {
4475                        channel_id: channel.id.to_proto(),
4476                        participant_user_ids: participants.clone(),
4477                    }],
4478                    ..Default::default()
4479                },
4480            )
4481        },
4482    );
4483}
4484
4485async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4486    let db = session.db().await;
4487
4488    let contacts = db.get_contacts(user_id).await?;
4489    let busy = db.is_user_busy(user_id).await?;
4490
4491    let pool = session.connection_pool().await;
4492    let updated_contact = contact_for_user(user_id, busy, &pool);
4493    for contact in contacts {
4494        if let db::Contact::Accepted {
4495            user_id: contact_user_id,
4496            ..
4497        } = contact
4498        {
4499            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4500                session
4501                    .peer
4502                    .send(
4503                        contact_conn_id,
4504                        proto::UpdateContacts {
4505                            contacts: vec![updated_contact.clone()],
4506                            remove_contacts: Default::default(),
4507                            incoming_requests: Default::default(),
4508                            remove_incoming_requests: Default::default(),
4509                            outgoing_requests: Default::default(),
4510                            remove_outgoing_requests: Default::default(),
4511                        },
4512                    )
4513                    .trace_err();
4514            }
4515        }
4516    }
4517    Ok(())
4518}
4519
/// Removes `connection_id` from the room it is currently in, if any, and
/// notifies everyone affected.
///
/// When the database reports that nothing was left, this returns `Ok(())`
/// without doing anything else. Otherwise it:
///
/// 1. notifies collaborators of every project that was left,
/// 2. broadcasts the updated room to its participants,
/// 3. broadcasts the room's participants to the associated channel, if any,
/// 4. sends `CallCanceled` to every connection of users whose calls were
///    canceled,
/// 5. re-broadcasts contact state for the session's user and those users,
/// 6. removes the session's user from the LiveKit room when a LiveKit client
///    is configured, deleting the room when the database marked it deleted.
///
/// # Errors
///
/// Returns database errors from leaving the room or refreshing contacts.
/// Peer-send and LiveKit failures are passed to `trace_err` and not
/// propagated.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    // Users whose contact entry must be re-broadcast after the room is left.
    let mut contacts_to_update = HashSet::default();

    // Declared up front and assigned inside the `if let` below, so the values
    // remain usable after `left_room` has gone out of scope.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // NOTE(review): everything needed later is moved out of `left_room` here so
    // that it is dropped at the end of this block, before the broadcasts and
    // database calls below — presumably because it holds a database guard.
    // Confirm before restructuring (e.g. into `let ... else`).
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` itself is moved out on the line after next.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The database reported nothing to leave for this connection.
        return Ok(());
    }

    // If the room is associated with a channel, let the channel's subscribers
    // know about the new participant list.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection pool handle is released before the awaits below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            // Tell every connection of the callee that the call was canceled.
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures do not fail the call.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4593
4594async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4595    let left_channel_buffers = session
4596        .db()
4597        .await
4598        .leave_channel_buffers(session.connection_id)
4599        .await?;
4600
4601    for left_buffer in left_channel_buffers {
4602        channel_buffer_updated(
4603            session.connection_id,
4604            left_buffer.connections,
4605            &proto::UpdateChannelBufferCollaborators {
4606                channel_id: left_buffer.channel_id.to_proto(),
4607                collaborators: left_buffer.collaborators,
4608            },
4609            &session.peer,
4610        );
4611    }
4612
4613    Ok(())
4614}
4615
4616fn project_left(project: &db::LeftProject, session: &Session) {
4617    for connection_id in &project.connection_ids {
4618        if project.should_unshare {
4619            session
4620                .peer
4621                .send(
4622                    *connection_id,
4623                    proto::UnshareProject {
4624                        project_id: project.id.to_proto(),
4625                    },
4626                )
4627                .trace_err();
4628        } else {
4629            session
4630                .peer
4631                .send(
4632                    *connection_id,
4633                    proto::RemoveProjectCollaborator {
4634                        project_id: project.id.to_proto(),
4635                        peer_id: Some(session.connection_id.into()),
4636                    },
4637                )
4638                .trace_err();
4639        }
4640    }
4641}
4642
/// Extension trait for turning a fallible value into an `Option` via
/// [`ResultExt::trace_err`] instead of propagating its error.
pub trait ResultExt {
    /// The success value type returned inside `Some` by `trace_err`.
    type Ok;

    /// Consumes `self` and returns the success value, if any.
    ///
    /// The implementation for `Result` in this file logs the error with
    /// `tracing::error!` and returns `None` on failure.
    fn trace_err(self) -> Option<Self::Ok>;
}
4648
4649impl<T, E> ResultExt for Result<T, E>
4650where
4651    E: std::fmt::Debug,
4652{
4653    type Ok = T;
4654
4655    #[track_caller]
4656    fn trace_err(self) -> Option<T> {
4657        match self {
4658            Ok(value) => Some(value),
4659            Err(error) => {
4660                tracing::error!("{:?}", error);
4661                None
4662            }
4663        }
4664    }
4665}