rpc.rs

   1mod connection_pool;
   2
   3use crate::api::billing::find_or_create_billing_customer;
   4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   5use crate::db::billing_subscription::SubscriptionKind;
   6use crate::llm::db::LlmDatabase;
   7use crate::llm::{
   8    AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
   9    MIN_ACCOUNT_AGE_FOR_LLM_USE,
  10};
  11use crate::stripe_client::StripeCustomerId;
  12use crate::{
  13    AppState, Error, Result, auth,
  14    db::{
  15        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
  16        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  17        NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
  18        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  19    },
  20    executor::Executor,
  21};
  22use anyhow::{Context as _, anyhow, bail};
  23use async_tungstenite::tungstenite::{
  24    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  25};
  26use axum::{
  27    Extension, Router, TypedHeader,
  28    body::Body,
  29    extract::{
  30        ConnectInfo, WebSocketUpgrade,
  31        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  32    },
  33    headers::{Header, HeaderName},
  34    http::StatusCode,
  35    middleware,
  36    response::IntoResponse,
  37    routing::get,
  38};
  39use chrono::Utc;
  40use collections::{HashMap, HashSet};
  41pub use connection_pool::{ConnectionPool, ZedVersion};
  42use core::fmt::{self, Debug, Formatter};
  43use reqwest_client::ReqwestClient;
  44use rpc::proto::split_repository_update;
  45use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  46
  47use futures::{
  48    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  49    stream::FuturesUnordered,
  50};
  51use prometheus::{IntGauge, register_int_gauge};
  52use rpc::{
  53    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  54    proto::{
  55        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  56        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  57    },
  58};
  59use semantic_version::SemanticVersion;
  60use serde::{Serialize, Serializer};
  61use std::{
  62    any::TypeId,
  63    future::Future,
  64    marker::PhantomData,
  65    mem,
  66    net::SocketAddr,
  67    ops::{Deref, DerefMut},
  68    rc::Rc,
  69    sync::{
  70        Arc, OnceLock,
  71        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  72    },
  73    time::{Duration, Instant},
  74};
  75use time::OffsetDateTime;
  76use tokio::sync::{Semaphore, watch};
  77use tower::ServiceBuilder;
  78use tracing::{
  79    Instrument,
  80    field::{self},
  81    info_span, instrument,
  82};
  83
  84pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  85
  86// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  87pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  88
  89const MESSAGE_COUNT_PER_PAGE: usize = 100;
  90const MAX_MESSAGE_LEN: usize = 1024;
  91const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  92const MAX_CONCURRENT_CONNECTIONS: usize = 512;
  93
  94static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
  95
  96type MessageHandler =
  97    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  98
/// Token for one slot of the process-wide connection limit.
///
/// Obtained via `ConnectionGuard::try_acquire`, which increments
/// `CONCURRENT_CONNECTIONS`; dropping the guard decrements it again.
pub struct ConnectionGuard;
 100
 101impl ConnectionGuard {
 102    pub fn try_acquire() -> Result<Self, ()> {
 103        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
 104        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 105            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 106            tracing::error!(
 107                "too many concurrent connections: {}",
 108                current_connections + 1
 109            );
 110            return Err(());
 111        }
 112        Ok(ConnectionGuard)
 113    }
 114}
 115
 116impl Drop for ConnectionGuard {
 117    fn drop(&mut self) {
 118        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 119    }
 120}
 121
/// The reply side of an incoming request of type `R`.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    // Raised by `send`. NOTE(review): presumably inspected by the request
    // dispatcher to detect handlers that never replied — confirm.
    responded: Arc<AtomicBool>,
}
 127
 128impl<R: RequestMessage> Response<R> {
 129    fn send(self, payload: R::Response) -> Result<()> {
 130        self.responded.store(true, SeqCst);
 131        self.peer.respond(self.receipt, payload)?;
 132        Ok(())
 133    }
 134}
 135
/// The identity acting on a connection.
#[derive(Clone, Debug)]
pub enum Principal {
    /// A user acting as themselves.
    User(User),
    /// `admin` acting as `user` (logged as the "impersonator").
    Impersonated { user: User, admin: User },
}
 141
 142impl Principal {
 143    fn user(&self) -> &User {
 144        match self {
 145            Principal::User(user) => user,
 146            Principal::Impersonated { user, .. } => user,
 147        }
 148    }
 149
 150    fn update_span(&self, span: &tracing::Span) {
 151        match &self {
 152            Principal::User(user) => {
 153                span.record("user_id", user.id.0);
 154                span.record("login", &user.github_login);
 155            }
 156            Principal::Impersonated { user, admin } => {
 157                span.record("user_id", user.id.0);
 158                span.record("login", &user.github_login);
 159                span.record("impersonator", &admin.github_login);
 160            }
 161        }
 162    }
 163}
 164
/// State passed to every message handler together with the message (see the
/// `MessageHandler` signature). Shared resources are held behind `Arc`s, so
/// clones share them.
#[derive(Clone)]
struct Session {
    /// Who is acting on this connection (possibly an impersonating admin).
    principal: Principal,
    connection_id: ConnectionId,
    /// Database access, serialized through an async mutex (see `Session::db`).
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Synchronous lock; access it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    // NOTE(review): not read anywhere in this chunk; `allow(unused)` silences
    // the dead-code lint while the field is kept.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    _executor: Executor,
}
 180
 181impl Session {
 182    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 183        #[cfg(test)]
 184        tokio::task::yield_now().await;
 185        let guard = self.db.lock().await;
 186        #[cfg(test)]
 187        tokio::task::yield_now().await;
 188        guard
 189    }
 190
 191    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 192        #[cfg(test)]
 193        tokio::task::yield_now().await;
 194        let guard = self.connection_pool.lock();
 195        ConnectionPoolGuard {
 196            guard,
 197            _not_send: PhantomData,
 198        }
 199    }
 200
 201    fn is_staff(&self) -> bool {
 202        match &self.principal {
 203            Principal::User(user) => user.admin,
 204            Principal::Impersonated { .. } => true,
 205        }
 206    }
 207
 208    fn user_id(&self) -> UserId {
 209        match &self.principal {
 210            Principal::User(user) => user.id,
 211            Principal::Impersonated { user, .. } => user.id,
 212        }
 213    }
 214
 215    pub fn email(&self) -> Option<String> {
 216        match &self.principal {
 217            Principal::User(user) => user.email_address.clone(),
 218            Principal::Impersonated { user, .. } => user.email_address.clone(),
 219        }
 220    }
 221}
 222
 223impl Debug for Session {
 224    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 225        let mut result = f.debug_struct("Session");
 226        match &self.principal {
 227            Principal::User(user) => {
 228                result.field("user", &user.github_login);
 229            }
 230            Principal::Impersonated { user, admin } => {
 231                result.field("user", &user.github_login);
 232                result.field("impersonator", &admin.github_login);
 233            }
 234        }
 235        result.field("connection_id", &self.connection_id).finish()
 236    }
 237}
 238
/// Thin wrapper around the shared `Database`; dereferences to it.
struct DbHandle(Arc<Database>);
 240
 241impl Deref for DbHandle {
 242    type Target = Database;
 243
 244    fn deref(&self) -> &Self::Target {
 245        self.0.as_ref()
 246    }
 247}
 248
/// The collaboration RPC server. It owns the peer, the shared connection pool,
/// and the table of registered message handlers.
pub struct Server {
    // Behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // One handler per message type, keyed by `TypeId::of::<M>()` (see `add_handler`).
    handlers: HashMap<TypeId, MessageHandler>,
    // Set to `true` by `teardown` and back to `false` by `reset`.
    // NOTE(review): receivers live outside this chunk.
    teardown: watch::Sender<bool>,
}
 257
/// Lock guard over the shared `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes this type `!Send`. Any future that
/// holds the guard across an `.await` is therefore `!Send` too, and cannot be
/// used where a `Send` future is required. This keeps the synchronous
/// `parking_lot` lock from being held across suspension points.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 262
/// Serializable view of the server's peer and (locked) connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the guard's `Deref` target (presumably the pool itself).
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 269
 270pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 271where
 272    S: Serializer,
 273    T: Deref<Target = U>,
 274    U: Serialize,
 275{
 276    Serialize::serialize(value.deref(), serializer)
 277}
 278
 279impl Server {
 280    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 281        let mut server = Self {
 282            id: parking_lot::Mutex::new(id),
 283            peer: Peer::new(id.0 as u32),
 284            app_state: app_state.clone(),
 285            connection_pool: Default::default(),
 286            handlers: Default::default(),
 287            teardown: watch::channel(false).0,
 288        };
 289
 290        server
 291            .add_request_handler(ping)
 292            .add_request_handler(create_room)
 293            .add_request_handler(join_room)
 294            .add_request_handler(rejoin_room)
 295            .add_request_handler(leave_room)
 296            .add_request_handler(set_room_participant_role)
 297            .add_request_handler(call)
 298            .add_request_handler(cancel_call)
 299            .add_message_handler(decline_call)
 300            .add_request_handler(update_participant_location)
 301            .add_request_handler(share_project)
 302            .add_message_handler(unshare_project)
 303            .add_request_handler(join_project)
 304            .add_message_handler(leave_project)
 305            .add_request_handler(update_project)
 306            .add_request_handler(update_worktree)
 307            .add_request_handler(update_repository)
 308            .add_request_handler(remove_repository)
 309            .add_message_handler(start_language_server)
 310            .add_message_handler(update_language_server)
 311            .add_message_handler(update_diagnostic_summary)
 312            .add_message_handler(update_worktree_settings)
 313            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 314            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 315            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 316            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 317            .add_request_handler(forward_find_search_candidates_request)
 318            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 319            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 320            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 321            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 322            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 323            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 324            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 325            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 326            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 327            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 328            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 329            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 330            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 331            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 332            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 333            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 334            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 335            .add_request_handler(
 336                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 337            )
 338            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 339            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 340            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 341            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 342            .add_request_handler(
 343                forward_read_only_project_request::<proto::LanguageServerIdForName>,
 344            )
 345            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
 346            .add_request_handler(
 347                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 348            )
 349            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 350            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 351            .add_request_handler(
 352                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 353            )
 354            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 355            .add_request_handler(
 356                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 357            )
 358            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 359            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 360            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 361            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 362            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 363            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 364            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 365            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 366            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 367            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 368            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 369            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 370            .add_request_handler(
 371                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 372            )
 373            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 374            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 375            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 376            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 377            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 378            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 379            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 380            .add_message_handler(create_buffer_for_peer)
 381            .add_request_handler(update_buffer)
 382            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 383            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 384            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 385            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 386            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 387            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 388            .add_message_handler(
 389                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 390            )
 391            .add_request_handler(get_users)
 392            .add_request_handler(fuzzy_search_users)
 393            .add_request_handler(request_contact)
 394            .add_request_handler(remove_contact)
 395            .add_request_handler(respond_to_contact_request)
 396            .add_message_handler(subscribe_to_channels)
 397            .add_request_handler(create_channel)
 398            .add_request_handler(delete_channel)
 399            .add_request_handler(invite_channel_member)
 400            .add_request_handler(remove_channel_member)
 401            .add_request_handler(set_channel_member_role)
 402            .add_request_handler(set_channel_visibility)
 403            .add_request_handler(rename_channel)
 404            .add_request_handler(join_channel_buffer)
 405            .add_request_handler(leave_channel_buffer)
 406            .add_message_handler(update_channel_buffer)
 407            .add_request_handler(rejoin_channel_buffers)
 408            .add_request_handler(get_channel_members)
 409            .add_request_handler(respond_to_channel_invite)
 410            .add_request_handler(join_channel)
 411            .add_request_handler(join_channel_chat)
 412            .add_message_handler(leave_channel_chat)
 413            .add_request_handler(send_channel_message)
 414            .add_request_handler(remove_channel_message)
 415            .add_request_handler(update_channel_message)
 416            .add_request_handler(get_channel_messages)
 417            .add_request_handler(get_channel_messages_by_id)
 418            .add_request_handler(get_notifications)
 419            .add_request_handler(mark_notification_as_read)
 420            .add_request_handler(move_channel)
 421            .add_request_handler(reorder_channel)
 422            .add_request_handler(follow)
 423            .add_message_handler(unfollow)
 424            .add_message_handler(update_followers)
 425            .add_request_handler(get_private_user_info)
 426            .add_request_handler(get_llm_api_token)
 427            .add_request_handler(accept_terms_of_service)
 428            .add_message_handler(acknowledge_channel_message)
 429            .add_message_handler(acknowledge_buffer_version)
 430            .add_request_handler(get_supermaven_api_key)
 431            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 432            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 433            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 434            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 435            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 436            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 437            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 438            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 439            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 440            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 441            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 442            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 443            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 444            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 445            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 446            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 447            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 448            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 449            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 450            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 451            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 452            .add_message_handler(update_context);
 453
 454        Arc::new(server)
 455    }
 456
 457    pub async fn start(&self) -> Result<()> {
 458        let server_id = *self.id.lock();
 459        let app_state = self.app_state.clone();
 460        let peer = self.peer.clone();
 461        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 462        let pool = self.connection_pool.clone();
 463        let livekit_client = self.app_state.livekit_client.clone();
 464
 465        let span = info_span!("start server");
 466        self.app_state.executor.spawn_detached(
 467            async move {
 468                tracing::info!("waiting for cleanup timeout");
 469                timeout.await;
 470                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 471
 472                app_state
 473                    .db
 474                    .delete_stale_channel_chat_participants(
 475                        &app_state.config.zed_environment,
 476                        server_id,
 477                    )
 478                    .await
 479                    .trace_err();
 480
 481                if let Some((room_ids, channel_ids)) = app_state
 482                    .db
 483                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 484                    .await
 485                    .trace_err()
 486                {
 487                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 488                    tracing::info!(
 489                        stale_channel_buffer_count = channel_ids.len(),
 490                        "retrieved stale channel buffers"
 491                    );
 492
 493                    for channel_id in channel_ids {
 494                        if let Some(refreshed_channel_buffer) = app_state
 495                            .db
 496                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 497                            .await
 498                            .trace_err()
 499                        {
 500                            for connection_id in refreshed_channel_buffer.connection_ids {
 501                                peer.send(
 502                                    connection_id,
 503                                    proto::UpdateChannelBufferCollaborators {
 504                                        channel_id: channel_id.to_proto(),
 505                                        collaborators: refreshed_channel_buffer
 506                                            .collaborators
 507                                            .clone(),
 508                                    },
 509                                )
 510                                .trace_err();
 511                            }
 512                        }
 513                    }
 514
 515                    for room_id in room_ids {
 516                        let mut contacts_to_update = HashSet::default();
 517                        let mut canceled_calls_to_user_ids = Vec::new();
 518                        let mut livekit_room = String::new();
 519                        let mut delete_livekit_room = false;
 520
 521                        if let Some(mut refreshed_room) = app_state
 522                            .db
 523                            .clear_stale_room_participants(room_id, server_id)
 524                            .await
 525                            .trace_err()
 526                        {
 527                            tracing::info!(
 528                                room_id = room_id.0,
 529                                new_participant_count = refreshed_room.room.participants.len(),
 530                                "refreshed room"
 531                            );
 532                            room_updated(&refreshed_room.room, &peer);
 533                            if let Some(channel) = refreshed_room.channel.as_ref() {
 534                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 535                            }
 536                            contacts_to_update
 537                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 538                            contacts_to_update
 539                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 540                            canceled_calls_to_user_ids =
 541                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 542                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
 543                            delete_livekit_room = refreshed_room.room.participants.is_empty();
 544                        }
 545
 546                        {
 547                            let pool = pool.lock();
 548                            for canceled_user_id in canceled_calls_to_user_ids {
 549                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 550                                    peer.send(
 551                                        connection_id,
 552                                        proto::CallCanceled {
 553                                            room_id: room_id.to_proto(),
 554                                        },
 555                                    )
 556                                    .trace_err();
 557                                }
 558                            }
 559                        }
 560
 561                        for user_id in contacts_to_update {
 562                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 563                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 564                            if let Some((busy, contacts)) = busy.zip(contacts) {
 565                                let pool = pool.lock();
 566                                let updated_contact = contact_for_user(user_id, busy, &pool);
 567                                for contact in contacts {
 568                                    if let db::Contact::Accepted {
 569                                        user_id: contact_user_id,
 570                                        ..
 571                                    } = contact
 572                                    {
 573                                        for contact_conn_id in
 574                                            pool.user_connection_ids(contact_user_id)
 575                                        {
 576                                            peer.send(
 577                                                contact_conn_id,
 578                                                proto::UpdateContacts {
 579                                                    contacts: vec![updated_contact.clone()],
 580                                                    remove_contacts: Default::default(),
 581                                                    incoming_requests: Default::default(),
 582                                                    remove_incoming_requests: Default::default(),
 583                                                    outgoing_requests: Default::default(),
 584                                                    remove_outgoing_requests: Default::default(),
 585                                                },
 586                                            )
 587                                            .trace_err();
 588                                        }
 589                                    }
 590                                }
 591                            }
 592                        }
 593
 594                        if let Some(live_kit) = livekit_client.as_ref() {
 595                            if delete_livekit_room {
 596                                live_kit.delete_room(livekit_room).await.trace_err();
 597                            }
 598                        }
 599                    }
 600                }
 601
 602                app_state
 603                    .db
 604                    .delete_stale_channel_chat_participants(
 605                        &app_state.config.zed_environment,
 606                        server_id,
 607                    )
 608                    .await
 609                    .trace_err();
 610
 611                app_state
 612                    .db
 613                    .clear_old_worktree_entries(server_id)
 614                    .await
 615                    .trace_err();
 616
 617                app_state
 618                    .db
 619                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 620                    .await
 621                    .trace_err();
 622            }
 623            .instrument(span),
 624        );
 625        Ok(())
 626    }
 627
 628    pub fn teardown(&self) {
 629        self.peer.teardown();
 630        self.connection_pool.lock().reset();
 631        let _ = self.teardown.send(true);
 632    }
 633
 634    #[cfg(test)]
 635    pub fn reset(&self, id: ServerId) {
 636        self.teardown();
 637        *self.id.lock() = id;
 638        self.peer.reset(id.0 as u32);
 639        let _ = self.teardown.send(false);
 640    }
 641
 642    #[cfg(test)]
 643    pub fn id(&self) -> ServerId {
 644        *self.id.lock()
 645    }
 646
 647    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 648    where
 649        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 650        Fut: 'static + Send + Future<Output = Result<()>>,
 651        M: EnvelopedMessage,
 652    {
 653        let prev_handler = self.handlers.insert(
 654            TypeId::of::<M>(),
 655            Box::new(move |envelope, session| {
 656                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 657                let received_at = envelope.received_at;
 658                tracing::info!("message received");
 659                let start_time = Instant::now();
 660                let future = (handler)(*envelope, session);
 661                async move {
 662                    let result = future.await;
 663                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 664                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 665                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 666                    let payload_type = M::NAME;
 667
 668                    match result {
 669                        Err(error) => {
 670                            tracing::error!(
 671                                ?error,
 672                                total_duration_ms,
 673                                processing_duration_ms,
 674                                queue_duration_ms,
 675                                payload_type,
 676                                "error handling message"
 677                            )
 678                        }
 679                        Ok(()) => tracing::info!(
 680                            total_duration_ms,
 681                            processing_duration_ms,
 682                            queue_duration_ms,
 683                            "finished handling message"
 684                        ),
 685                    }
 686                }
 687                .boxed()
 688            }),
 689        );
 690        if prev_handler.is_some() {
 691            panic!("registered a handler for the same message twice");
 692        }
 693        self
 694    }
 695
 696    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 697    where
 698        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 699        Fut: 'static + Send + Future<Output = Result<()>>,
 700        M: EnvelopedMessage,
 701    {
 702        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 703        self
 704    }
 705
 706    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 707    where
 708        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 709        Fut: Send + Future<Output = Result<()>>,
 710        M: RequestMessage,
 711    {
 712        let handler = Arc::new(handler);
 713        self.add_handler(move |envelope, session| {
 714            let receipt = envelope.receipt();
 715            let handler = handler.clone();
 716            async move {
 717                let peer = session.peer.clone();
 718                let responded = Arc::new(AtomicBool::default());
 719                let response = Response {
 720                    peer: peer.clone(),
 721                    responded: responded.clone(),
 722                    receipt,
 723                };
 724                match (handler)(envelope.payload, response, session).await {
 725                    Ok(()) => {
 726                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 727                            Ok(())
 728                        } else {
 729                            Err(anyhow!("handler did not send a response"))?
 730                        }
 731                    }
 732                    Err(error) => {
 733                        let proto_err = match &error {
 734                            Error::Internal(err) => err.to_proto(),
 735                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 736                        };
 737                        peer.respond_with_error(receipt, proto_err)?;
 738                        Err(error)
 739                    }
 740                }
 741            }
 742        })
 743    }
 744
    /// Drives one client connection from handshake to sign-out.
    ///
    /// The returned future:
    /// 1. bails out immediately if the teardown flag is already set,
    /// 2. registers the connection with the peer and builds the per-connection [`Session`],
    /// 3. runs `send_initial_client_update`, then releases `connection_guard`,
    /// 4. dispatches incoming messages to the registered handlers until the connection
    ///    closes, its I/O future finishes, or the teardown signal changes.
    ///
    /// After a normal close or an I/O error it runs `connection_lost`; when the teardown
    /// signal fires it returns without that cleanup. `send_connection_id`, when provided,
    /// is handed the new `ConnectionId` by `send_initial_client_update`.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields declared `Empty` are filled in later: user fields by
        // `principal.update_span`, the country code just below, and `connection_id`
        // once the peer assigns one.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse to start serving if the server is already tearing down.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }

            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    // NOTE(review): this return leaves the connection added to `this.peer`
                    // above without a matching disconnect — confirm that is acceptable.
                    return;
                }
            };

            // Only built when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                // NOTE(review): this return skips `connection_lost`, so a pool entry that
                // `send_initial_client_update` may already have added is not removed — confirm.
                return;
            }
            // The guard is held only until the initial update has been sent.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers per connection run at once: a permit is taken before the
            // next message is read and released when that message's handler finishes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown is checked first, then I/O, then progress on foreground
                // handlers, and only then a new incoming message.
                futures::select_biased! {
                    // Teardown exits without running `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            // Entered only while the handler future is constructed; the
                            // future itself is instrumented with the same span below.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background handlers run detached on the executor; foreground
                                // handlers are polled by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set drops any foreground handlers that have not completed yet.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 895
 896    async fn send_initial_client_update(
 897        &self,
 898        connection_id: ConnectionId,
 899        zed_version: ZedVersion,
 900        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 901        session: &Session,
 902    ) -> Result<()> {
 903        self.peer.send(
 904            connection_id,
 905            proto::Hello {
 906                peer_id: Some(connection_id.into()),
 907            },
 908        )?;
 909        tracing::info!("sent hello message");
 910        if let Some(send_connection_id) = send_connection_id.take() {
 911            let _ = send_connection_id.send(connection_id);
 912        }
 913
 914        match &session.principal {
 915            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 916                if !user.connected_once {
 917                    self.peer.send(connection_id, proto::ShowContacts {})?;
 918                    self.app_state
 919                        .db
 920                        .set_user_connected_once(user.id, true)
 921                        .await?;
 922                }
 923
 924                update_user_plan(session).await?;
 925
 926                let contacts = self.app_state.db.get_contacts(user.id).await?;
 927
 928                {
 929                    let mut pool = self.connection_pool.lock();
 930                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 931                    self.peer.send(
 932                        connection_id,
 933                        build_initial_contacts_update(contacts, &pool),
 934                    )?;
 935                }
 936
 937                if should_auto_subscribe_to_channels(zed_version) {
 938                    subscribe_user_to_channels(user.id, session).await?;
 939                }
 940
 941                if let Some(incoming_call) =
 942                    self.app_state.db.incoming_call_for_user(user.id).await?
 943                {
 944                    self.peer.send(connection_id, incoming_call)?;
 945                }
 946
 947                update_user_contacts(user.id, session).await?;
 948            }
 949        }
 950
 951        Ok(())
 952    }
 953
 954    pub async fn invite_code_redeemed(
 955        self: &Arc<Self>,
 956        inviter_id: UserId,
 957        invitee_id: UserId,
 958    ) -> Result<()> {
 959        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 960            if let Some(code) = &user.invite_code {
 961                let pool = self.connection_pool.lock();
 962                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 963                for connection_id in pool.user_connection_ids(inviter_id) {
 964                    self.peer.send(
 965                        connection_id,
 966                        proto::UpdateContacts {
 967                            contacts: vec![invitee_contact.clone()],
 968                            ..Default::default()
 969                        },
 970                    )?;
 971                    self.peer.send(
 972                        connection_id,
 973                        proto::UpdateInviteInfo {
 974                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 975                            count: user.invite_count as u32,
 976                        },
 977                    )?;
 978                }
 979            }
 980        }
 981        Ok(())
 982    }
 983
 984    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 985        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 986            if let Some(invite_code) = &user.invite_code {
 987                let pool = self.connection_pool.lock();
 988                for connection_id in pool.user_connection_ids(user_id) {
 989                    self.peer.send(
 990                        connection_id,
 991                        proto::UpdateInviteInfo {
 992                            url: format!(
 993                                "{}{}",
 994                                self.app_state.config.invite_link_prefix, invite_code
 995                            ),
 996                            count: user.invite_count as u32,
 997                        },
 998                    )?;
 999                }
1000            }
1001        }
1002        Ok(())
1003    }
1004
1005    pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1006        let user = self
1007            .app_state
1008            .db
1009            .get_user_by_id(user_id)
1010            .await?
1011            .context("user not found")?;
1012
1013        let update_user_plan = make_update_user_plan_message(
1014            &user,
1015            user.admin,
1016            &self.app_state.db,
1017            self.app_state.llm_db.clone(),
1018        )
1019        .await?;
1020
1021        let pool = self.connection_pool.lock();
1022        for connection_id in pool.user_connection_ids(user_id) {
1023            self.peer
1024                .send(connection_id, update_user_plan.clone())
1025                .trace_err();
1026        }
1027
1028        Ok(())
1029    }
1030
1031    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
1032        let pool = self.connection_pool.lock();
1033        for connection_id in pool.user_connection_ids(user_id) {
1034            self.peer
1035                .send(connection_id, proto::RefreshLlmToken {})
1036                .trace_err();
1037        }
1038    }
1039
1040    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1041        ServerSnapshot {
1042            connection_pool: ConnectionPoolGuard {
1043                guard: self.connection_pool.lock(),
1044                _not_send: PhantomData,
1045            },
1046            peer: &self.peer,
1047        }
1048    }
1049}
1050
1051impl Deref for ConnectionPoolGuard<'_> {
1052    type Target = ConnectionPool;
1053
1054    fn deref(&self) -> &Self::Target {
1055        &self.guard
1056    }
1057}
1058
1059impl DerefMut for ConnectionPoolGuard<'_> {
1060    fn deref_mut(&mut self) -> &mut Self::Target {
1061        &mut self.guard
1062    }
1063}
1064
1065impl Drop for ConnectionPoolGuard<'_> {
1066    fn drop(&mut self) {
1067        #[cfg(test)]
1068        self.check_invariants();
1069    }
1070}
1071
1072fn broadcast<F>(
1073    sender_id: Option<ConnectionId>,
1074    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1075    mut f: F,
1076) where
1077    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1078{
1079    for receiver_id in receiver_ids {
1080        if Some(receiver_id) != sender_id {
1081            if let Err(error) = f(receiver_id) {
1082                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1083            }
1084        }
1085    }
1086}
1087
1088pub struct ProtocolVersion(u32);
1089
1090impl Header for ProtocolVersion {
1091    fn name() -> &'static HeaderName {
1092        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1093        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1094    }
1095
1096    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1097    where
1098        Self: Sized,
1099        I: Iterator<Item = &'i axum::http::HeaderValue>,
1100    {
1101        let version = values
1102            .next()
1103            .ok_or_else(axum::headers::Error::invalid)?
1104            .to_str()
1105            .map_err(|_| axum::headers::Error::invalid())?
1106            .parse()
1107            .map_err(|_| axum::headers::Error::invalid())?;
1108        Ok(Self(version))
1109    }
1110
1111    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1112        values.extend([self.0.to_string().parse().unwrap()]);
1113    }
1114}
1115
1116pub struct AppVersionHeader(SemanticVersion);
1117impl Header for AppVersionHeader {
1118    fn name() -> &'static HeaderName {
1119        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1120        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1121    }
1122
1123    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1124    where
1125        Self: Sized,
1126        I: Iterator<Item = &'i axum::http::HeaderValue>,
1127    {
1128        let version = values
1129            .next()
1130            .ok_or_else(axum::headers::Error::invalid)?
1131            .to_str()
1132            .map_err(|_| axum::headers::Error::invalid())?
1133            .parse()
1134            .map_err(|_| axum::headers::Error::invalid())?;
1135        Ok(Self(version))
1136    }
1137
1138    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1139        values.extend([self.0.to_string().parse().unwrap()]);
1140    }
1141}
1142
/// Builds the router for the collab RPC endpoints.
///
/// Layer order matters here. A `Router::layer` call wraps only the routes added before it:
/// - `/rpc` (WebSocket upgrade, see `handle_websocket_request`) is added before the
///   `ServiceBuilder` layer, so it gets the `AppState` extension and the
///   `auth::validate_header` middleware.
/// - `/metrics` is added after that layer, so neither applies to it.
/// - The final `Extension(server)` layer wraps both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1154
1155pub async fn handle_websocket_request(
1156    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1157    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1158    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1159    Extension(server): Extension<Arc<Server>>,
1160    Extension(principal): Extension<Principal>,
1161    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1162    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1163    ws: WebSocketUpgrade,
1164) -> axum::response::Response {
1165    if protocol_version != rpc::PROTOCOL_VERSION {
1166        return (
1167            StatusCode::UPGRADE_REQUIRED,
1168            "client must be upgraded".to_string(),
1169        )
1170            .into_response();
1171    }
1172
1173    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1174        return (
1175            StatusCode::UPGRADE_REQUIRED,
1176            "no version header found".to_string(),
1177        )
1178            .into_response();
1179    };
1180
1181    if !version.can_collaborate() {
1182        return (
1183            StatusCode::UPGRADE_REQUIRED,
1184            "client must be upgraded".to_string(),
1185        )
1186            .into_response();
1187    }
1188
1189    let socket_address = socket_address.to_string();
1190
1191    // Acquire connection guard before WebSocket upgrade
1192    let connection_guard = match ConnectionGuard::try_acquire() {
1193        Ok(guard) => guard,
1194        Err(()) => {
1195            return (
1196                StatusCode::SERVICE_UNAVAILABLE,
1197                "Too many concurrent connections",
1198            )
1199                .into_response();
1200        }
1201    };
1202
1203    ws.on_upgrade(move |socket| {
1204        let socket = socket
1205            .map_ok(to_tungstenite_message)
1206            .err_into()
1207            .with(|message| async move { to_axum_message(message) });
1208        let connection = Connection::new(Box::pin(socket));
1209        async move {
1210            server
1211                .handle_connection(
1212                    connection,
1213                    socket_address,
1214                    principal,
1215                    version,
1216                    country_code_header.map(|header| header.to_string()),
1217                    system_id_header.map(|header| header.to_string()),
1218                    None,
1219                    Executor::Production,
1220                    Some(connection_guard),
1221                )
1222                .await;
1223        }
1224    })
1225}
1226
1227pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1228    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1229    let connections_metric = CONNECTIONS_METRIC
1230        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1231
1232    let connections = server
1233        .connection_pool
1234        .lock()
1235        .connections()
1236        .filter(|connection| !connection.admin)
1237        .count();
1238    connections_metric.set(connections as _);
1239
1240    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1241    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1242        register_int_gauge!(
1243            "shared_projects",
1244            "number of open projects with one or more guests"
1245        )
1246        .unwrap()
1247    });
1248
1249    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1250    shared_projects_metric.set(shared_projects as _);
1251
1252    let encoder = prometheus::TextEncoder::new();
1253    let metric_families = prometheus::gather();
1254    let encoded_metrics = encoder
1255        .encode_to_string(&metric_families)
1256        .map_err(|err| anyhow!("{err}"))?;
1257    Ok(encoded_metrics)
1258}
1259
/// Cleans up after a connection's message loop has ended.
///
/// Runs these steps immediately:
/// - disconnects the peer;
/// - removes the connection from the pool;
/// - calls the database's `connection_lost`.
///
/// It then waits for whichever comes first, `RECONNECT_TIMEOUT` or a change on the
/// teardown channel. The timeout is presumably a grace period for the client to
/// reconnect — confirm.
///
/// If the timeout wins, it:
/// - leaves the room and channel buffers;
/// - declines any pending call, but only if the user no longer has any online connection;
/// - updates the user's contacts.
///
/// If the teardown channel changes first, those steps are skipped.
///
/// # Errors
/// Returns an error if removing the connection from the pool or updating contacts
/// fails. The other steps only report failures through `trace_err`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased: if both are ready, the timeout branch is taken first.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call when no other connection of this user is online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1306
1307/// Acknowledges a ping from a client, used to keep the connection alive.
1308async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1309    response.send(proto::Ack {})?;
1310    Ok(())
1311}
1312
1313/// Creates a new room for calling (outside of channels)
1314async fn create_room(
1315    _request: proto::CreateRoom,
1316    response: Response<proto::CreateRoom>,
1317    session: Session,
1318) -> Result<()> {
1319    let livekit_room = nanoid::nanoid!(30);
1320
1321    let live_kit_connection_info = util::maybe!(async {
1322        let live_kit = session.app_state.livekit_client.as_ref();
1323        let live_kit = live_kit?;
1324        let user_id = session.user_id().to_string();
1325
1326        let token = live_kit
1327            .room_token(&livekit_room, &user_id.to_string())
1328            .trace_err()?;
1329
1330        Some(proto::LiveKitConnectionInfo {
1331            server_url: live_kit.url().into(),
1332            token,
1333            can_publish: true,
1334        })
1335    })
1336    .await;
1337
1338    let room = session
1339        .db()
1340        .await
1341        .create_room(session.user_id(), session.connection_id, &livekit_room)
1342        .await?;
1343
1344    response.send(proto::CreateRoomResponse {
1345        room: Some(room.clone()),
1346        live_kit_connection_info,
1347    })?;
1348
1349    update_user_contacts(session.user_id(), &session).await?;
1350    Ok(())
1351}
1352
1353/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1354async fn join_room(
1355    request: proto::JoinRoom,
1356    response: Response<proto::JoinRoom>,
1357    session: Session,
1358) -> Result<()> {
1359    let room_id = RoomId::from_proto(request.id);
1360
1361    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1362
1363    if let Some(channel_id) = channel_id {
1364        return join_channel_internal(channel_id, Box::new(response), session).await;
1365    }
1366
1367    let joined_room = {
1368        let room = session
1369            .db()
1370            .await
1371            .join_room(room_id, session.user_id(), session.connection_id)
1372            .await?;
1373        room_updated(&room.room, &session.peer);
1374        room.into_inner()
1375    };
1376
1377    for connection_id in session
1378        .connection_pool()
1379        .await
1380        .user_connection_ids(session.user_id())
1381    {
1382        session
1383            .peer
1384            .send(
1385                connection_id,
1386                proto::CallCanceled {
1387                    room_id: room_id.to_proto(),
1388                },
1389            )
1390            .trace_err();
1391    }
1392
1393    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1394    {
1395        live_kit
1396            .room_token(
1397                &joined_room.room.livekit_room,
1398                &session.user_id().to_string(),
1399            )
1400            .trace_err()
1401            .map(|token| proto::LiveKitConnectionInfo {
1402                server_url: live_kit.url().into(),
1403                token,
1404                can_publish: true,
1405            })
1406    } else {
1407        None
1408    };
1409
1410    response.send(proto::JoinRoomResponse {
1411        room: Some(joined_room.room),
1412        channel_id: None,
1413        live_kit_connection_info,
1414    })?;
1415
1416    update_user_contacts(session.user_id(), &session).await?;
1417    Ok(())
1418}
1419
/// Rejoins a room after the client lost its connection.
///
/// Steps, in order:
/// 1. Replies with the room, its reshared projects and its rejoined projects.
/// 2. Broadcasts the updated room.
/// 3. For each reshared project, tells its collaborators the caller's new peer id and
///    forwards the project's worktrees to every collaborator except the caller.
/// 4. Passes rejoined projects to `notify_rejoined_projects`.
/// 5. If the room has a channel, calls `channel_updated`.
/// 6. Updates the caller's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    // Work with the database result inside this block; only `room` and `channel`
    // (taken out via `into_inner`) are used after it.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell every collaborator that the caller's peer id changed.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktrees to everyone except the caller.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1511
1512fn notify_rejoined_projects(
1513    rejoined_projects: &mut Vec<RejoinedProject>,
1514    session: &Session,
1515) -> Result<()> {
1516    for project in rejoined_projects.iter() {
1517        for collaborator in &project.collaborators {
1518            session
1519                .peer
1520                .send(
1521                    collaborator.connection_id,
1522                    proto::UpdateProjectCollaborator {
1523                        project_id: project.id.to_proto(),
1524                        old_peer_id: Some(project.old_connection_id.into()),
1525                        new_peer_id: Some(session.connection_id.into()),
1526                    },
1527                )
1528                .trace_err();
1529        }
1530    }
1531
1532    for project in rejoined_projects {
1533        for worktree in mem::take(&mut project.worktrees) {
1534            // Stream this worktree's entries.
1535            let message = proto::UpdateWorktree {
1536                project_id: project.id.to_proto(),
1537                worktree_id: worktree.id,
1538                abs_path: worktree.abs_path.clone(),
1539                root_name: worktree.root_name,
1540                updated_entries: worktree.updated_entries,
1541                removed_entries: worktree.removed_entries,
1542                scan_id: worktree.scan_id,
1543                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1544                updated_repositories: worktree.updated_repositories,
1545                removed_repositories: worktree.removed_repositories,
1546            };
1547            for update in proto::split_worktree_update(message) {
1548                session.peer.send(session.connection_id, update)?;
1549            }
1550
1551            // Stream this worktree's diagnostics.
1552            for summary in worktree.diagnostic_summaries {
1553                session.peer.send(
1554                    session.connection_id,
1555                    proto::UpdateDiagnosticSummary {
1556                        project_id: project.id.to_proto(),
1557                        worktree_id: worktree.id,
1558                        summary: Some(summary),
1559                    },
1560                )?;
1561            }
1562
1563            for settings_file in worktree.settings_files {
1564                session.peer.send(
1565                    session.connection_id,
1566                    proto::UpdateWorktreeSettings {
1567                        project_id: project.id.to_proto(),
1568                        worktree_id: worktree.id,
1569                        path: settings_file.path,
1570                        content: Some(settings_file.content),
1571                        kind: Some(settings_file.kind.to_proto().into()),
1572                    },
1573                )?;
1574            }
1575        }
1576
1577        for repository in mem::take(&mut project.updated_repositories) {
1578            for update in split_repository_update(repository) {
1579                session.peer.send(session.connection_id, update)?;
1580            }
1581        }
1582
1583        for id in mem::take(&mut project.removed_repositories) {
1584            session.peer.send(
1585                session.connection_id,
1586                proto::RemoveRepository {
1587                    project_id: project.id.to_proto(),
1588                    id,
1589                },
1590            )?;
1591        }
1592    }
1593
1594    Ok(())
1595}
1596
1597/// leave room disconnects from the room.
1598async fn leave_room(
1599    _: proto::LeaveRoom,
1600    response: Response<proto::LeaveRoom>,
1601    session: Session,
1602) -> Result<()> {
1603    leave_room_for_session(&session, session.connection_id).await?;
1604    response.send(proto::Ack {})?;
1605    Ok(())
1606}
1607
1608/// Updates the permissions of someone else in the room.
1609async fn set_room_participant_role(
1610    request: proto::SetRoomParticipantRole,
1611    response: Response<proto::SetRoomParticipantRole>,
1612    session: Session,
1613) -> Result<()> {
1614    let user_id = UserId::from_proto(request.user_id);
1615    let role = ChannelRole::from(request.role());
1616
1617    let (livekit_room, can_publish) = {
1618        let room = session
1619            .db()
1620            .await
1621            .set_room_participant_role(
1622                session.user_id(),
1623                RoomId::from_proto(request.room_id),
1624                user_id,
1625                role,
1626            )
1627            .await?;
1628
1629        let livekit_room = room.livekit_room.clone();
1630        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1631        room_updated(&room, &session.peer);
1632        (livekit_room, can_publish)
1633    };
1634
1635    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1636        live_kit
1637            .update_participant(
1638                livekit_room.clone(),
1639                request.user_id.to_string(),
1640                livekit_api::proto::ParticipantPermission {
1641                    can_subscribe: true,
1642                    can_publish,
1643                    can_publish_data: can_publish,
1644                    hidden: false,
1645                    recorder: false,
1646                },
1647            )
1648            .await
1649            .trace_err();
1650    }
1651
1652    response.send(proto::Ack {})?;
1653    Ok(())
1654}
1655
/// Call someone else into the current room.
///
/// Steps:
/// 1. Check that the called user is one of the caller's contacts.
/// 2. Record the call in the database, broadcast the updated room, and
///    refresh the called user's contacts.
/// 3. Send the incoming call as a request to every connection of the called
///    user concurrently. The first successful reply acks this request.
///
/// If no connection replies successfully, the call is marked as failed, the
/// room and contacts are refreshed again, and an error ("failed to ring user")
/// is returned. This also covers the case where the user has no connections.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        // `?` on this `Err` returns early, converting the anyhow error into
        // this function's error type.
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Scoped so the guard returned by the database `call` is dropped before
    // `update_user_contacts` runs.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        // Move the incoming-call message out of the guard so it outlives it.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the called user at once.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // Replies are yielded in completion order. The first success acks the
    // request. Failures are only logged. Any requests still in flight are
    // dropped together with `calls` when we return.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection accepted the call: record the failure and refresh state.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    // Always returns early with the converted error.
    Err(anyhow!("failed to ring user"))?
}
1724
1725/// Cancel an outgoing call.
1726async fn cancel_call(
1727    request: proto::CancelCall,
1728    response: Response<proto::CancelCall>,
1729    session: Session,
1730) -> Result<()> {
1731    let called_user_id = UserId::from_proto(request.called_user_id);
1732    let room_id = RoomId::from_proto(request.room_id);
1733    {
1734        let room = session
1735            .db()
1736            .await
1737            .cancel_call(room_id, session.connection_id, called_user_id)
1738            .await?;
1739        room_updated(&room, &session.peer);
1740    }
1741
1742    for connection_id in session
1743        .connection_pool()
1744        .await
1745        .user_connection_ids(called_user_id)
1746    {
1747        session
1748            .peer
1749            .send(
1750                connection_id,
1751                proto::CallCanceled {
1752                    room_id: room_id.to_proto(),
1753                },
1754            )
1755            .trace_err();
1756    }
1757    response.send(proto::Ack {})?;
1758
1759    update_user_contacts(called_user_id, &session).await?;
1760    Ok(())
1761}
1762
1763/// Decline an incoming call.
1764async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1765    let room_id = RoomId::from_proto(message.room_id);
1766    {
1767        let room = session
1768            .db()
1769            .await
1770            .decline_call(Some(room_id), session.user_id())
1771            .await?
1772            .context("declining call")?;
1773        room_updated(&room, &session.peer);
1774    }
1775
1776    for connection_id in session
1777        .connection_pool()
1778        .await
1779        .user_connection_ids(session.user_id())
1780    {
1781        session
1782            .peer
1783            .send(
1784                connection_id,
1785                proto::CallCanceled {
1786                    room_id: room_id.to_proto(),
1787                },
1788            )
1789            .trace_err();
1790    }
1791    update_user_contacts(session.user_id(), &session).await?;
1792    Ok(())
1793}
1794
1795/// Updates other participants in the room with your current location.
1796async fn update_participant_location(
1797    request: proto::UpdateParticipantLocation,
1798    response: Response<proto::UpdateParticipantLocation>,
1799    session: Session,
1800) -> Result<()> {
1801    let room_id = RoomId::from_proto(request.room_id);
1802    let location = request.location.context("invalid location")?;
1803
1804    let db = session.db().await;
1805    let room = db
1806        .update_room_participant_location(room_id, session.connection_id, location)
1807        .await?;
1808
1809    room_updated(&room, &session.peer);
1810    response.send(proto::Ack {})?;
1811    Ok(())
1812}
1813
1814/// Share a project into the room.
1815async fn share_project(
1816    request: proto::ShareProject,
1817    response: Response<proto::ShareProject>,
1818    session: Session,
1819) -> Result<()> {
1820    let (project_id, room) = &*session
1821        .db()
1822        .await
1823        .share_project(
1824            RoomId::from_proto(request.room_id),
1825            session.connection_id,
1826            &request.worktrees,
1827            request.is_ssh_project,
1828        )
1829        .await?;
1830    response.send(proto::ShareProjectResponse {
1831        project_id: project_id.to_proto(),
1832    })?;
1833    room_updated(room, &session.peer);
1834
1835    Ok(())
1836}
1837
1838/// Unshare a project from the room.
1839async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1840    let project_id = ProjectId::from_proto(message.project_id);
1841    unshare_project_internal(project_id, session.connection_id, &session).await
1842}
1843
1844async fn unshare_project_internal(
1845    project_id: ProjectId,
1846    connection_id: ConnectionId,
1847    session: &Session,
1848) -> Result<()> {
1849    let delete = {
1850        let room_guard = session
1851            .db()
1852            .await
1853            .unshare_project(project_id, connection_id)
1854            .await?;
1855
1856        let (delete, room, guest_connection_ids) = &*room_guard;
1857
1858        let message = proto::UnshareProject {
1859            project_id: project_id.to_proto(),
1860        };
1861
1862        broadcast(
1863            Some(connection_id),
1864            guest_connection_ids.iter().copied(),
1865            |conn_id| session.peer.send(conn_id, message.clone()),
1866        );
1867        if let Some(room) = room {
1868            room_updated(room, &session.peer);
1869        }
1870
1871        *delete
1872    };
1873
1874    if delete {
1875        let db = session.db().await;
1876        db.delete_project(project_id).await?;
1877    }
1878
1879    Ok(())
1880}
1881
1882/// Join someone elses shared project.
1883async fn join_project(
1884    request: proto::JoinProject,
1885    response: Response<proto::JoinProject>,
1886    session: Session,
1887) -> Result<()> {
1888    let project_id = ProjectId::from_proto(request.project_id);
1889
1890    tracing::info!(%project_id, "join project");
1891
1892    let db = session.db().await;
1893    let (project, replica_id) = &mut *db
1894        .join_project(
1895            project_id,
1896            session.connection_id,
1897            session.user_id(),
1898            request.committer_name.clone(),
1899            request.committer_email.clone(),
1900        )
1901        .await?;
1902    drop(db);
1903    tracing::info!(%project_id, "join remote project");
1904    let collaborators = project
1905        .collaborators
1906        .iter()
1907        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1908        .map(|collaborator| collaborator.to_proto())
1909        .collect::<Vec<_>>();
1910    let project_id = project.id;
1911    let guest_user_id = session.user_id();
1912
1913    let worktrees = project
1914        .worktrees
1915        .iter()
1916        .map(|(id, worktree)| proto::WorktreeMetadata {
1917            id: *id,
1918            root_name: worktree.root_name.clone(),
1919            visible: worktree.visible,
1920            abs_path: worktree.abs_path.clone(),
1921        })
1922        .collect::<Vec<_>>();
1923
1924    let add_project_collaborator = proto::AddProjectCollaborator {
1925        project_id: project_id.to_proto(),
1926        collaborator: Some(proto::Collaborator {
1927            peer_id: Some(session.connection_id.into()),
1928            replica_id: replica_id.0 as u32,
1929            user_id: guest_user_id.to_proto(),
1930            is_host: false,
1931            committer_name: request.committer_name.clone(),
1932            committer_email: request.committer_email.clone(),
1933        }),
1934    };
1935
1936    for collaborator in &collaborators {
1937        session
1938            .peer
1939            .send(
1940                collaborator.peer_id.unwrap().into(),
1941                add_project_collaborator.clone(),
1942            )
1943            .trace_err();
1944    }
1945
1946    // First, we send the metadata associated with each worktree.
1947    response.send(proto::JoinProjectResponse {
1948        project_id: project.id.0 as u64,
1949        worktrees: worktrees.clone(),
1950        replica_id: replica_id.0 as u32,
1951        collaborators: collaborators.clone(),
1952        language_servers: project.language_servers.clone(),
1953        role: project.role.into(),
1954    })?;
1955
1956    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1957        // Stream this worktree's entries.
1958        let message = proto::UpdateWorktree {
1959            project_id: project_id.to_proto(),
1960            worktree_id,
1961            abs_path: worktree.abs_path.clone(),
1962            root_name: worktree.root_name,
1963            updated_entries: worktree.entries,
1964            removed_entries: Default::default(),
1965            scan_id: worktree.scan_id,
1966            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1967            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1968            removed_repositories: Default::default(),
1969        };
1970        for update in proto::split_worktree_update(message) {
1971            session.peer.send(session.connection_id, update.clone())?;
1972        }
1973
1974        // Stream this worktree's diagnostics.
1975        for summary in worktree.diagnostic_summaries {
1976            session.peer.send(
1977                session.connection_id,
1978                proto::UpdateDiagnosticSummary {
1979                    project_id: project_id.to_proto(),
1980                    worktree_id: worktree.id,
1981                    summary: Some(summary),
1982                },
1983            )?;
1984        }
1985
1986        for settings_file in worktree.settings_files {
1987            session.peer.send(
1988                session.connection_id,
1989                proto::UpdateWorktreeSettings {
1990                    project_id: project_id.to_proto(),
1991                    worktree_id: worktree.id,
1992                    path: settings_file.path,
1993                    content: Some(settings_file.content),
1994                    kind: Some(settings_file.kind.to_proto() as i32),
1995                },
1996            )?;
1997        }
1998    }
1999
2000    for repository in mem::take(&mut project.repositories) {
2001        for update in split_repository_update(repository) {
2002            session.peer.send(session.connection_id, update)?;
2003        }
2004    }
2005
2006    for language_server in &project.language_servers {
2007        session.peer.send(
2008            session.connection_id,
2009            proto::UpdateLanguageServer {
2010                project_id: project_id.to_proto(),
2011                server_name: Some(language_server.name.clone()),
2012                language_server_id: language_server.id,
2013                variant: Some(
2014                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2015                        proto::LspDiskBasedDiagnosticsUpdated {},
2016                    ),
2017                ),
2018            },
2019        )?;
2020    }
2021
2022    Ok(())
2023}
2024
2025/// Leave someone elses shared project.
2026async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
2027    let sender_id = session.connection_id;
2028    let project_id = ProjectId::from_proto(request.project_id);
2029    let db = session.db().await;
2030
2031    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2032    tracing::info!(
2033        %project_id,
2034        "leave project"
2035    );
2036
2037    project_left(project, &session);
2038    if let Some(room) = room {
2039        room_updated(room, &session.peer);
2040    }
2041
2042    Ok(())
2043}
2044
2045/// Updates other participants with changes to the project
2046async fn update_project(
2047    request: proto::UpdateProject,
2048    response: Response<proto::UpdateProject>,
2049    session: Session,
2050) -> Result<()> {
2051    let project_id = ProjectId::from_proto(request.project_id);
2052    let (room, guest_connection_ids) = &*session
2053        .db()
2054        .await
2055        .update_project(project_id, session.connection_id, &request.worktrees)
2056        .await?;
2057    broadcast(
2058        Some(session.connection_id),
2059        guest_connection_ids.iter().copied(),
2060        |connection_id| {
2061            session
2062                .peer
2063                .forward_send(session.connection_id, connection_id, request.clone())
2064        },
2065    );
2066    if let Some(room) = room {
2067        room_updated(room, &session.peer);
2068    }
2069    response.send(proto::Ack {})?;
2070
2071    Ok(())
2072}
2073
2074/// Updates other participants with changes to the worktree
2075async fn update_worktree(
2076    request: proto::UpdateWorktree,
2077    response: Response<proto::UpdateWorktree>,
2078    session: Session,
2079) -> Result<()> {
2080    let guest_connection_ids = session
2081        .db()
2082        .await
2083        .update_worktree(&request, session.connection_id)
2084        .await?;
2085
2086    broadcast(
2087        Some(session.connection_id),
2088        guest_connection_ids.iter().copied(),
2089        |connection_id| {
2090            session
2091                .peer
2092                .forward_send(session.connection_id, connection_id, request.clone())
2093        },
2094    );
2095    response.send(proto::Ack {})?;
2096    Ok(())
2097}
2098
2099async fn update_repository(
2100    request: proto::UpdateRepository,
2101    response: Response<proto::UpdateRepository>,
2102    session: Session,
2103) -> Result<()> {
2104    let guest_connection_ids = session
2105        .db()
2106        .await
2107        .update_repository(&request, session.connection_id)
2108        .await?;
2109
2110    broadcast(
2111        Some(session.connection_id),
2112        guest_connection_ids.iter().copied(),
2113        |connection_id| {
2114            session
2115                .peer
2116                .forward_send(session.connection_id, connection_id, request.clone())
2117        },
2118    );
2119    response.send(proto::Ack {})?;
2120    Ok(())
2121}
2122
2123async fn remove_repository(
2124    request: proto::RemoveRepository,
2125    response: Response<proto::RemoveRepository>,
2126    session: Session,
2127) -> Result<()> {
2128    let guest_connection_ids = session
2129        .db()
2130        .await
2131        .remove_repository(&request, session.connection_id)
2132        .await?;
2133
2134    broadcast(
2135        Some(session.connection_id),
2136        guest_connection_ids.iter().copied(),
2137        |connection_id| {
2138            session
2139                .peer
2140                .forward_send(session.connection_id, connection_id, request.clone())
2141        },
2142    );
2143    response.send(proto::Ack {})?;
2144    Ok(())
2145}
2146
2147/// Updates other participants with changes to the diagnostics
2148async fn update_diagnostic_summary(
2149    message: proto::UpdateDiagnosticSummary,
2150    session: Session,
2151) -> Result<()> {
2152    let guest_connection_ids = session
2153        .db()
2154        .await
2155        .update_diagnostic_summary(&message, session.connection_id)
2156        .await?;
2157
2158    broadcast(
2159        Some(session.connection_id),
2160        guest_connection_ids.iter().copied(),
2161        |connection_id| {
2162            session
2163                .peer
2164                .forward_send(session.connection_id, connection_id, message.clone())
2165        },
2166    );
2167
2168    Ok(())
2169}
2170
2171/// Updates other participants with changes to the worktree settings
2172async fn update_worktree_settings(
2173    message: proto::UpdateWorktreeSettings,
2174    session: Session,
2175) -> Result<()> {
2176    let guest_connection_ids = session
2177        .db()
2178        .await
2179        .update_worktree_settings(&message, session.connection_id)
2180        .await?;
2181
2182    broadcast(
2183        Some(session.connection_id),
2184        guest_connection_ids.iter().copied(),
2185        |connection_id| {
2186            session
2187                .peer
2188                .forward_send(session.connection_id, connection_id, message.clone())
2189        },
2190    );
2191
2192    Ok(())
2193}
2194
2195/// Notify other participants that a language server has started.
2196async fn start_language_server(
2197    request: proto::StartLanguageServer,
2198    session: Session,
2199) -> Result<()> {
2200    let guest_connection_ids = session
2201        .db()
2202        .await
2203        .start_language_server(&request, session.connection_id)
2204        .await?;
2205
2206    broadcast(
2207        Some(session.connection_id),
2208        guest_connection_ids.iter().copied(),
2209        |connection_id| {
2210            session
2211                .peer
2212                .forward_send(session.connection_id, connection_id, request.clone())
2213        },
2214    );
2215    Ok(())
2216}
2217
2218/// Notify other participants that a language server has changed.
2219async fn update_language_server(
2220    request: proto::UpdateLanguageServer,
2221    session: Session,
2222) -> Result<()> {
2223    let project_id = ProjectId::from_proto(request.project_id);
2224    let project_connection_ids = session
2225        .db()
2226        .await
2227        .project_connection_ids(project_id, session.connection_id, true)
2228        .await?;
2229    broadcast(
2230        Some(session.connection_id),
2231        project_connection_ids.iter().copied(),
2232        |connection_id| {
2233            session
2234                .peer
2235                .forward_send(session.connection_id, connection_id, request.clone())
2236        },
2237    );
2238    Ok(())
2239}
2240
2241/// forward a project request to the host. These requests should be read only
2242/// as guests are allowed to send them.
2243async fn forward_read_only_project_request<T>(
2244    request: T,
2245    response: Response<T>,
2246    session: Session,
2247) -> Result<()>
2248where
2249    T: EntityMessage + RequestMessage,
2250{
2251    let project_id = ProjectId::from_proto(request.remote_entity_id());
2252    let host_connection_id = session
2253        .db()
2254        .await
2255        .host_for_read_only_project_request(project_id, session.connection_id)
2256        .await?;
2257    let payload = session
2258        .peer
2259        .forward_request(session.connection_id, host_connection_id, request)
2260        .await?;
2261    response.send(payload)?;
2262    Ok(())
2263}
2264
2265async fn forward_find_search_candidates_request(
2266    request: proto::FindSearchCandidates,
2267    response: Response<proto::FindSearchCandidates>,
2268    session: Session,
2269) -> Result<()> {
2270    let project_id = ProjectId::from_proto(request.remote_entity_id());
2271    let host_connection_id = session
2272        .db()
2273        .await
2274        .host_for_read_only_project_request(project_id, session.connection_id)
2275        .await?;
2276    let payload = session
2277        .peer
2278        .forward_request(session.connection_id, host_connection_id, request)
2279        .await?;
2280    response.send(payload)?;
2281    Ok(())
2282}
2283
2284/// forward a project request to the host. These requests are disallowed
2285/// for guests.
2286async fn forward_mutating_project_request<T>(
2287    request: T,
2288    response: Response<T>,
2289    session: Session,
2290) -> Result<()>
2291where
2292    T: EntityMessage + RequestMessage,
2293{
2294    let project_id = ProjectId::from_proto(request.remote_entity_id());
2295
2296    let host_connection_id = session
2297        .db()
2298        .await
2299        .host_for_mutating_project_request(project_id, session.connection_id)
2300        .await?;
2301    let payload = session
2302        .peer
2303        .forward_request(session.connection_id, host_connection_id, request)
2304        .await?;
2305    response.send(payload)?;
2306    Ok(())
2307}
2308
2309/// Notify other participants that a new buffer has been created
2310async fn create_buffer_for_peer(
2311    request: proto::CreateBufferForPeer,
2312    session: Session,
2313) -> Result<()> {
2314    session
2315        .db()
2316        .await
2317        .check_user_is_project_host(
2318            ProjectId::from_proto(request.project_id),
2319            session.connection_id,
2320        )
2321        .await?;
2322    let peer_id = request.peer_id.context("invalid peer id")?;
2323    session
2324        .peer
2325        .forward_send(session.connection_id, peer_id.into(), request)?;
2326    Ok(())
2327}
2328
2329/// Notify other participants that a buffer has been updated. This is
2330/// allowed for guests as long as the update is limited to selections.
2331async fn update_buffer(
2332    request: proto::UpdateBuffer,
2333    response: Response<proto::UpdateBuffer>,
2334    session: Session,
2335) -> Result<()> {
2336    let project_id = ProjectId::from_proto(request.project_id);
2337    let mut capability = Capability::ReadOnly;
2338
2339    for op in request.operations.iter() {
2340        match op.variant {
2341            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2342            Some(_) => capability = Capability::ReadWrite,
2343        }
2344    }
2345
2346    let host = {
2347        let guard = session
2348            .db()
2349            .await
2350            .connections_for_buffer_update(project_id, session.connection_id, capability)
2351            .await?;
2352
2353        let (host, guests) = &*guard;
2354
2355        broadcast(
2356            Some(session.connection_id),
2357            guests.clone(),
2358            |connection_id| {
2359                session
2360                    .peer
2361                    .forward_send(session.connection_id, connection_id, request.clone())
2362            },
2363        );
2364
2365        *host
2366    };
2367
2368    if host != session.connection_id {
2369        session
2370            .peer
2371            .forward_request(session.connection_id, host, request.clone())
2372            .await?;
2373    }
2374
2375    response.send(proto::Ack {})?;
2376    Ok(())
2377}
2378
2379async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2380    let project_id = ProjectId::from_proto(message.project_id);
2381
2382    let operation = message.operation.as_ref().context("invalid operation")?;
2383    let capability = match operation.variant.as_ref() {
2384        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2385            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2386                match buffer_op.variant {
2387                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2388                        Capability::ReadOnly
2389                    }
2390                    _ => Capability::ReadWrite,
2391                }
2392            } else {
2393                Capability::ReadWrite
2394            }
2395        }
2396        Some(_) => Capability::ReadWrite,
2397        None => Capability::ReadOnly,
2398    };
2399
2400    let guard = session
2401        .db()
2402        .await
2403        .connections_for_buffer_update(project_id, session.connection_id, capability)
2404        .await?;
2405
2406    let (host, guests) = &*guard;
2407
2408    broadcast(
2409        Some(session.connection_id),
2410        guests.iter().chain([host]).copied(),
2411        |connection_id| {
2412            session
2413                .peer
2414                .forward_send(session.connection_id, connection_id, message.clone())
2415        },
2416    );
2417
2418    Ok(())
2419}
2420
2421/// Notify other participants that a project has been updated.
2422async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2423    request: T,
2424    session: Session,
2425) -> Result<()> {
2426    let project_id = ProjectId::from_proto(request.remote_entity_id());
2427    let project_connection_ids = session
2428        .db()
2429        .await
2430        .project_connection_ids(project_id, session.connection_id, false)
2431        .await?;
2432
2433    broadcast(
2434        Some(session.connection_id),
2435        project_connection_ids.iter().copied(),
2436        |connection_id| {
2437            session
2438                .peer
2439                .forward_send(session.connection_id, connection_id, request.clone())
2440        },
2441    );
2442    Ok(())
2443}
2444
2445/// Start following another user in a call.
2446async fn follow(
2447    request: proto::Follow,
2448    response: Response<proto::Follow>,
2449    session: Session,
2450) -> Result<()> {
2451    let room_id = RoomId::from_proto(request.room_id);
2452    let project_id = request.project_id.map(ProjectId::from_proto);
2453    let leader_id = request.leader_id.context("invalid leader id")?.into();
2454    let follower_id = session.connection_id;
2455
2456    session
2457        .db()
2458        .await
2459        .check_room_participants(room_id, leader_id, session.connection_id)
2460        .await?;
2461
2462    let response_payload = session
2463        .peer
2464        .forward_request(session.connection_id, leader_id, request)
2465        .await?;
2466    response.send(response_payload)?;
2467
2468    if let Some(project_id) = project_id {
2469        let room = session
2470            .db()
2471            .await
2472            .follow(room_id, project_id, leader_id, follower_id)
2473            .await?;
2474        room_updated(&room, &session.peer);
2475    }
2476
2477    Ok(())
2478}
2479
2480/// Stop following another user in a call.
2481async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2482    let room_id = RoomId::from_proto(request.room_id);
2483    let project_id = request.project_id.map(ProjectId::from_proto);
2484    let leader_id = request.leader_id.context("invalid leader id")?.into();
2485    let follower_id = session.connection_id;
2486
2487    session
2488        .db()
2489        .await
2490        .check_room_participants(room_id, leader_id, session.connection_id)
2491        .await?;
2492
2493    session
2494        .peer
2495        .forward_send(session.connection_id, leader_id, request)?;
2496
2497    if let Some(project_id) = project_id {
2498        let room = session
2499            .db()
2500            .await
2501            .unfollow(room_id, project_id, leader_id, follower_id)
2502            .await?;
2503        room_updated(&room, &session.peer);
2504    }
2505
2506    Ok(())
2507}
2508
2509/// Notify everyone following you of your current location.
2510async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2511    let room_id = RoomId::from_proto(request.room_id);
2512    let database = session.db.lock().await;
2513
2514    let connection_ids = if let Some(project_id) = request.project_id {
2515        let project_id = ProjectId::from_proto(project_id);
2516        database
2517            .project_connection_ids(project_id, session.connection_id, true)
2518            .await?
2519    } else {
2520        database
2521            .room_connection_ids(room_id, session.connection_id)
2522            .await?
2523    };
2524
2525    // For now, don't send view update messages back to that view's current leader.
2526    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2527        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2528        _ => None,
2529    });
2530
2531    for connection_id in connection_ids.iter().cloned() {
2532        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2533            session
2534                .peer
2535                .forward_send(session.connection_id, connection_id, request.clone())?;
2536        }
2537    }
2538    Ok(())
2539}
2540
2541/// Get public data about users.
2542async fn get_users(
2543    request: proto::GetUsers,
2544    response: Response<proto::GetUsers>,
2545    session: Session,
2546) -> Result<()> {
2547    let user_ids = request
2548        .user_ids
2549        .into_iter()
2550        .map(UserId::from_proto)
2551        .collect();
2552    let users = session
2553        .db()
2554        .await
2555        .get_users_by_ids(user_ids)
2556        .await?
2557        .into_iter()
2558        .map(|user| proto::User {
2559            id: user.id.to_proto(),
2560            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2561            github_login: user.github_login,
2562            name: user.name,
2563        })
2564        .collect();
2565    response.send(proto::UsersResponse { users })?;
2566    Ok(())
2567}
2568
2569/// Search for users (to invite) buy Github login
2570async fn fuzzy_search_users(
2571    request: proto::FuzzySearchUsers,
2572    response: Response<proto::FuzzySearchUsers>,
2573    session: Session,
2574) -> Result<()> {
2575    let query = request.query;
2576    let users = match query.len() {
2577        0 => vec![],
2578        1 | 2 => session
2579            .db()
2580            .await
2581            .get_user_by_github_login(&query)
2582            .await?
2583            .into_iter()
2584            .collect(),
2585        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2586    };
2587    let users = users
2588        .into_iter()
2589        .filter(|user| user.id != session.user_id())
2590        .map(|user| proto::User {
2591            id: user.id.to_proto(),
2592            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2593            github_login: user.github_login,
2594            name: user.name,
2595        })
2596        .collect();
2597    response.send(proto::UsersResponse { users })?;
2598    Ok(())
2599}
2600
2601/// Send a contact request to another user.
2602async fn request_contact(
2603    request: proto::RequestContact,
2604    response: Response<proto::RequestContact>,
2605    session: Session,
2606) -> Result<()> {
2607    let requester_id = session.user_id();
2608    let responder_id = UserId::from_proto(request.responder_id);
2609    if requester_id == responder_id {
2610        return Err(anyhow!("cannot add yourself as a contact"))?;
2611    }
2612
2613    let notifications = session
2614        .db()
2615        .await
2616        .send_contact_request(requester_id, responder_id)
2617        .await?;
2618
2619    // Update outgoing contact requests of requester
2620    let mut update = proto::UpdateContacts::default();
2621    update.outgoing_requests.push(responder_id.to_proto());
2622    for connection_id in session
2623        .connection_pool()
2624        .await
2625        .user_connection_ids(requester_id)
2626    {
2627        session.peer.send(connection_id, update.clone())?;
2628    }
2629
2630    // Update incoming contact requests of responder
2631    let mut update = proto::UpdateContacts::default();
2632    update
2633        .incoming_requests
2634        .push(proto::IncomingContactRequest {
2635            requester_id: requester_id.to_proto(),
2636        });
2637    let connection_pool = session.connection_pool().await;
2638    for connection_id in connection_pool.user_connection_ids(responder_id) {
2639        session.peer.send(connection_id, update.clone())?;
2640    }
2641
2642    send_notifications(&connection_pool, &session.peer, notifications);
2643
2644    response.send(proto::Ack {})?;
2645    Ok(())
2646}
2647
2648/// Accept or decline a contact request
2649async fn respond_to_contact_request(
2650    request: proto::RespondToContactRequest,
2651    response: Response<proto::RespondToContactRequest>,
2652    session: Session,
2653) -> Result<()> {
2654    let responder_id = session.user_id();
2655    let requester_id = UserId::from_proto(request.requester_id);
2656    let db = session.db().await;
2657    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2658        db.dismiss_contact_notification(responder_id, requester_id)
2659            .await?;
2660    } else {
2661        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2662
2663        let notifications = db
2664            .respond_to_contact_request(responder_id, requester_id, accept)
2665            .await?;
2666        let requester_busy = db.is_user_busy(requester_id).await?;
2667        let responder_busy = db.is_user_busy(responder_id).await?;
2668
2669        let pool = session.connection_pool().await;
2670        // Update responder with new contact
2671        let mut update = proto::UpdateContacts::default();
2672        if accept {
2673            update
2674                .contacts
2675                .push(contact_for_user(requester_id, requester_busy, &pool));
2676        }
2677        update
2678            .remove_incoming_requests
2679            .push(requester_id.to_proto());
2680        for connection_id in pool.user_connection_ids(responder_id) {
2681            session.peer.send(connection_id, update.clone())?;
2682        }
2683
2684        // Update requester with new contact
2685        let mut update = proto::UpdateContacts::default();
2686        if accept {
2687            update
2688                .contacts
2689                .push(contact_for_user(responder_id, responder_busy, &pool));
2690        }
2691        update
2692            .remove_outgoing_requests
2693            .push(responder_id.to_proto());
2694
2695        for connection_id in pool.user_connection_ids(requester_id) {
2696            session.peer.send(connection_id, update.clone())?;
2697        }
2698
2699        send_notifications(&pool, &session.peer, notifications);
2700    }
2701
2702    response.send(proto::Ack {})?;
2703    Ok(())
2704}
2705
2706/// Remove a contact.
2707async fn remove_contact(
2708    request: proto::RemoveContact,
2709    response: Response<proto::RemoveContact>,
2710    session: Session,
2711) -> Result<()> {
2712    let requester_id = session.user_id();
2713    let responder_id = UserId::from_proto(request.user_id);
2714    let db = session.db().await;
2715    let (contact_accepted, deleted_notification_id) =
2716        db.remove_contact(requester_id, responder_id).await?;
2717
2718    let pool = session.connection_pool().await;
2719    // Update outgoing contact requests of requester
2720    let mut update = proto::UpdateContacts::default();
2721    if contact_accepted {
2722        update.remove_contacts.push(responder_id.to_proto());
2723    } else {
2724        update
2725            .remove_outgoing_requests
2726            .push(responder_id.to_proto());
2727    }
2728    for connection_id in pool.user_connection_ids(requester_id) {
2729        session.peer.send(connection_id, update.clone())?;
2730    }
2731
2732    // Update incoming contact requests of responder
2733    let mut update = proto::UpdateContacts::default();
2734    if contact_accepted {
2735        update.remove_contacts.push(requester_id.to_proto());
2736    } else {
2737        update
2738            .remove_incoming_requests
2739            .push(requester_id.to_proto());
2740    }
2741    for connection_id in pool.user_connection_ids(responder_id) {
2742        session.peer.send(connection_id, update.clone())?;
2743        if let Some(notification_id) = deleted_notification_id {
2744            session.peer.send(
2745                connection_id,
2746                proto::DeleteNotification {
2747                    notification_id: notification_id.to_proto(),
2748                },
2749            )?;
2750        }
2751    }
2752
2753    response.send(proto::Ack {})?;
2754    Ok(())
2755}
2756
2757fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2758    version.0.minor() < 139
2759}
2760
2761async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2762    if is_staff {
2763        return Ok(proto::Plan::ZedPro);
2764    }
2765
2766    let subscription = db.get_active_billing_subscription(user_id).await?;
2767    let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2768
2769    let plan = if let Some(subscription_kind) = subscription_kind {
2770        match subscription_kind {
2771            SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2772            SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2773            SubscriptionKind::ZedFree => proto::Plan::Free,
2774        }
2775    } else {
2776        proto::Plan::Free
2777    };
2778
2779    Ok(plan)
2780}
2781
2782async fn make_update_user_plan_message(
2783    user: &User,
2784    is_staff: bool,
2785    db: &Arc<Database>,
2786    llm_db: Option<Arc<LlmDatabase>>,
2787) -> Result<proto::UpdateUserPlan> {
2788    let feature_flags = db.get_user_flags(user.id).await?;
2789    let plan = current_plan(db, user.id, is_staff).await?;
2790    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
2791    let billing_preferences = db.get_billing_preferences(user.id).await?;
2792
2793    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
2794        let subscription = db.get_active_billing_subscription(user.id).await?;
2795
2796        let subscription_period =
2797            crate::db::billing_subscription::Model::current_period(subscription, is_staff);
2798
2799        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
2800            llm_db
2801                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
2802                .await?
2803        } else {
2804            None
2805        };
2806
2807        (subscription_period, usage)
2808    } else {
2809        (None, None)
2810    };
2811
2812    let bypass_account_age_check = feature_flags
2813        .iter()
2814        .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
2815    let account_too_young = !matches!(plan, proto::Plan::ZedPro)
2816        && !bypass_account_age_check
2817        && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;
2818
2819    Ok(proto::UpdateUserPlan {
2820        plan: plan.into(),
2821        trial_started_at: billing_customer
2822            .as_ref()
2823            .and_then(|billing_customer| billing_customer.trial_started_at)
2824            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2825        is_usage_based_billing_enabled: if is_staff {
2826            Some(true)
2827        } else {
2828            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
2829        },
2830        subscription_period: subscription_period.map(|(started_at, ended_at)| {
2831            proto::SubscriptionPeriod {
2832                started_at: started_at.timestamp() as u64,
2833                ended_at: ended_at.timestamp() as u64,
2834            }
2835        }),
2836        account_too_young: Some(account_too_young),
2837        has_overdue_invoices: billing_customer
2838            .map(|billing_customer| billing_customer.has_overdue_invoices),
2839        usage: Some(
2840            usage
2841                .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
2842                .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
2843        ),
2844    })
2845}
2846
2847fn model_requests_limit(
2848    plan: zed_llm_client::Plan,
2849    feature_flags: &Vec<String>,
2850) -> zed_llm_client::UsageLimit {
2851    match plan.model_requests_limit() {
2852        zed_llm_client::UsageLimit::Limited(limit) => {
2853            let limit = if plan == zed_llm_client::Plan::ZedProTrial
2854                && feature_flags
2855                    .iter()
2856                    .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2857            {
2858                1_000
2859            } else {
2860                limit
2861            };
2862
2863            zed_llm_client::UsageLimit::Limited(limit)
2864        }
2865        zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
2866    }
2867}
2868
2869fn subscription_usage_to_proto(
2870    plan: proto::Plan,
2871    usage: crate::llm::db::subscription_usage::Model,
2872    feature_flags: &Vec<String>,
2873) -> proto::SubscriptionUsage {
2874    let plan = match plan {
2875        proto::Plan::Free => zed_llm_client::Plan::ZedFree,
2876        proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
2877        proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
2878    };
2879
2880    proto::SubscriptionUsage {
2881        model_requests_usage_amount: usage.model_requests as u32,
2882        model_requests_usage_limit: Some(proto::UsageLimit {
2883            variant: Some(match model_requests_limit(plan, feature_flags) {
2884                zed_llm_client::UsageLimit::Limited(limit) => {
2885                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2886                        limit: limit as u32,
2887                    })
2888                }
2889                zed_llm_client::UsageLimit::Unlimited => {
2890                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2891                }
2892            }),
2893        }),
2894        edit_predictions_usage_amount: usage.edit_predictions as u32,
2895        edit_predictions_usage_limit: Some(proto::UsageLimit {
2896            variant: Some(match plan.edit_predictions_limit() {
2897                zed_llm_client::UsageLimit::Limited(limit) => {
2898                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2899                        limit: limit as u32,
2900                    })
2901                }
2902                zed_llm_client::UsageLimit::Unlimited => {
2903                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2904                }
2905            }),
2906        }),
2907    }
2908}
2909
2910fn make_default_subscription_usage(
2911    plan: proto::Plan,
2912    feature_flags: &Vec<String>,
2913) -> proto::SubscriptionUsage {
2914    let plan = match plan {
2915        proto::Plan::Free => zed_llm_client::Plan::ZedFree,
2916        proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
2917        proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
2918    };
2919
2920    proto::SubscriptionUsage {
2921        model_requests_usage_amount: 0,
2922        model_requests_usage_limit: Some(proto::UsageLimit {
2923            variant: Some(match model_requests_limit(plan, feature_flags) {
2924                zed_llm_client::UsageLimit::Limited(limit) => {
2925                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2926                        limit: limit as u32,
2927                    })
2928                }
2929                zed_llm_client::UsageLimit::Unlimited => {
2930                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2931                }
2932            }),
2933        }),
2934        edit_predictions_usage_amount: 0,
2935        edit_predictions_usage_limit: Some(proto::UsageLimit {
2936            variant: Some(match plan.edit_predictions_limit() {
2937                zed_llm_client::UsageLimit::Limited(limit) => {
2938                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2939                        limit: limit as u32,
2940                    })
2941                }
2942                zed_llm_client::UsageLimit::Unlimited => {
2943                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2944                }
2945            }),
2946        }),
2947    }
2948}
2949
2950async fn update_user_plan(session: &Session) -> Result<()> {
2951    let db = session.db().await;
2952
2953    let update_user_plan = make_update_user_plan_message(
2954        session.principal.user(),
2955        session.is_staff(),
2956        &db.0,
2957        session.app_state.llm_db.clone(),
2958    )
2959    .await?;
2960
2961    session
2962        .peer
2963        .send(session.connection_id, update_user_plan)
2964        .trace_err();
2965
2966    Ok(())
2967}
2968
2969async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2970    subscribe_user_to_channels(session.user_id(), &session).await?;
2971    Ok(())
2972}
2973
2974async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2975    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2976    let mut pool = session.connection_pool().await;
2977    for membership in &channels_for_user.channel_memberships {
2978        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2979    }
2980    session.peer.send(
2981        session.connection_id,
2982        build_update_user_channels(&channels_for_user),
2983    )?;
2984    session.peer.send(
2985        session.connection_id,
2986        build_channels_update(channels_for_user),
2987    )?;
2988    Ok(())
2989}
2990
2991/// Creates a new channel.
2992async fn create_channel(
2993    request: proto::CreateChannel,
2994    response: Response<proto::CreateChannel>,
2995    session: Session,
2996) -> Result<()> {
2997    let db = session.db().await;
2998
2999    let parent_id = request.parent_id.map(ChannelId::from_proto);
3000    let (channel, membership) = db
3001        .create_channel(&request.name, parent_id, session.user_id())
3002        .await?;
3003
3004    let root_id = channel.root_id();
3005    let channel = Channel::from_model(channel);
3006
3007    response.send(proto::CreateChannelResponse {
3008        channel: Some(channel.to_proto()),
3009        parent_id: request.parent_id,
3010    })?;
3011
3012    let mut connection_pool = session.connection_pool().await;
3013    if let Some(membership) = membership {
3014        connection_pool.subscribe_to_channel(
3015            membership.user_id,
3016            membership.channel_id,
3017            membership.role,
3018        );
3019        let update = proto::UpdateUserChannels {
3020            channel_memberships: vec![proto::ChannelMembership {
3021                channel_id: membership.channel_id.to_proto(),
3022                role: membership.role.into(),
3023            }],
3024            ..Default::default()
3025        };
3026        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3027            session.peer.send(connection_id, update.clone())?;
3028        }
3029    }
3030
3031    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3032        if !role.can_see_channel(channel.visibility) {
3033            continue;
3034        }
3035
3036        let update = proto::UpdateChannels {
3037            channels: vec![channel.to_proto()],
3038            ..Default::default()
3039        };
3040        session.peer.send(connection_id, update.clone())?;
3041    }
3042
3043    Ok(())
3044}
3045
3046/// Delete a channel
3047async fn delete_channel(
3048    request: proto::DeleteChannel,
3049    response: Response<proto::DeleteChannel>,
3050    session: Session,
3051) -> Result<()> {
3052    let db = session.db().await;
3053
3054    let channel_id = request.channel_id;
3055    let (root_channel, removed_channels) = db
3056        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3057        .await?;
3058    response.send(proto::Ack {})?;
3059
3060    // Notify members of removed channels
3061    let mut update = proto::UpdateChannels::default();
3062    update
3063        .delete_channels
3064        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3065
3066    let connection_pool = session.connection_pool().await;
3067    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3068        session.peer.send(connection_id, update.clone())?;
3069    }
3070
3071    Ok(())
3072}
3073
3074/// Invite someone to join a channel.
3075async fn invite_channel_member(
3076    request: proto::InviteChannelMember,
3077    response: Response<proto::InviteChannelMember>,
3078    session: Session,
3079) -> Result<()> {
3080    let db = session.db().await;
3081    let channel_id = ChannelId::from_proto(request.channel_id);
3082    let invitee_id = UserId::from_proto(request.user_id);
3083    let InviteMemberResult {
3084        channel,
3085        notifications,
3086    } = db
3087        .invite_channel_member(
3088            channel_id,
3089            invitee_id,
3090            session.user_id(),
3091            request.role().into(),
3092        )
3093        .await?;
3094
3095    let update = proto::UpdateChannels {
3096        channel_invitations: vec![channel.to_proto()],
3097        ..Default::default()
3098    };
3099
3100    let connection_pool = session.connection_pool().await;
3101    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3102        session.peer.send(connection_id, update.clone())?;
3103    }
3104
3105    send_notifications(&connection_pool, &session.peer, notifications);
3106
3107    response.send(proto::Ack {})?;
3108    Ok(())
3109}
3110
3111/// remove someone from a channel
3112async fn remove_channel_member(
3113    request: proto::RemoveChannelMember,
3114    response: Response<proto::RemoveChannelMember>,
3115    session: Session,
3116) -> Result<()> {
3117    let db = session.db().await;
3118    let channel_id = ChannelId::from_proto(request.channel_id);
3119    let member_id = UserId::from_proto(request.user_id);
3120
3121    let RemoveChannelMemberResult {
3122        membership_update,
3123        notification_id,
3124    } = db
3125        .remove_channel_member(channel_id, member_id, session.user_id())
3126        .await?;
3127
3128    let mut connection_pool = session.connection_pool().await;
3129    notify_membership_updated(
3130        &mut connection_pool,
3131        membership_update,
3132        member_id,
3133        &session.peer,
3134    );
3135    for connection_id in connection_pool.user_connection_ids(member_id) {
3136        if let Some(notification_id) = notification_id {
3137            session
3138                .peer
3139                .send(
3140                    connection_id,
3141                    proto::DeleteNotification {
3142                        notification_id: notification_id.to_proto(),
3143                    },
3144                )
3145                .trace_err();
3146        }
3147    }
3148
3149    response.send(proto::Ack {})?;
3150    Ok(())
3151}
3152
3153/// Toggle the channel between public and private.
3154/// Care is taken to maintain the invariant that public channels only descend from public channels,
3155/// (though members-only channels can appear at any point in the hierarchy).
3156async fn set_channel_visibility(
3157    request: proto::SetChannelVisibility,
3158    response: Response<proto::SetChannelVisibility>,
3159    session: Session,
3160) -> Result<()> {
3161    let db = session.db().await;
3162    let channel_id = ChannelId::from_proto(request.channel_id);
3163    let visibility = request.visibility().into();
3164
3165    let channel_model = db
3166        .set_channel_visibility(channel_id, visibility, session.user_id())
3167        .await?;
3168    let root_id = channel_model.root_id();
3169    let channel = Channel::from_model(channel_model);
3170
3171    let mut connection_pool = session.connection_pool().await;
3172    for (user_id, role) in connection_pool
3173        .channel_user_ids(root_id)
3174        .collect::<Vec<_>>()
3175        .into_iter()
3176    {
3177        let update = if role.can_see_channel(channel.visibility) {
3178            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3179            proto::UpdateChannels {
3180                channels: vec![channel.to_proto()],
3181                ..Default::default()
3182            }
3183        } else {
3184            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3185            proto::UpdateChannels {
3186                delete_channels: vec![channel.id.to_proto()],
3187                ..Default::default()
3188            }
3189        };
3190
3191        for connection_id in connection_pool.user_connection_ids(user_id) {
3192            session.peer.send(connection_id, update.clone())?;
3193        }
3194    }
3195
3196    response.send(proto::Ack {})?;
3197    Ok(())
3198}
3199
3200/// Alter the role for a user in the channel.
3201async fn set_channel_member_role(
3202    request: proto::SetChannelMemberRole,
3203    response: Response<proto::SetChannelMemberRole>,
3204    session: Session,
3205) -> Result<()> {
3206    let db = session.db().await;
3207    let channel_id = ChannelId::from_proto(request.channel_id);
3208    let member_id = UserId::from_proto(request.user_id);
3209    let result = db
3210        .set_channel_member_role(
3211            channel_id,
3212            session.user_id(),
3213            member_id,
3214            request.role().into(),
3215        )
3216        .await?;
3217
3218    match result {
3219        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3220            let mut connection_pool = session.connection_pool().await;
3221            notify_membership_updated(
3222                &mut connection_pool,
3223                membership_update,
3224                member_id,
3225                &session.peer,
3226            )
3227        }
3228        db::SetMemberRoleResult::InviteUpdated(channel) => {
3229            let update = proto::UpdateChannels {
3230                channel_invitations: vec![channel.to_proto()],
3231                ..Default::default()
3232            };
3233
3234            for connection_id in session
3235                .connection_pool()
3236                .await
3237                .user_connection_ids(member_id)
3238            {
3239                session.peer.send(connection_id, update.clone())?;
3240            }
3241        }
3242    }
3243
3244    response.send(proto::Ack {})?;
3245    Ok(())
3246}
3247
3248/// Change the name of a channel
3249async fn rename_channel(
3250    request: proto::RenameChannel,
3251    response: Response<proto::RenameChannel>,
3252    session: Session,
3253) -> Result<()> {
3254    let db = session.db().await;
3255    let channel_id = ChannelId::from_proto(request.channel_id);
3256    let channel_model = db
3257        .rename_channel(channel_id, session.user_id(), &request.name)
3258        .await?;
3259    let root_id = channel_model.root_id();
3260    let channel = Channel::from_model(channel_model);
3261
3262    response.send(proto::RenameChannelResponse {
3263        channel: Some(channel.to_proto()),
3264    })?;
3265
3266    let connection_pool = session.connection_pool().await;
3267    let update = proto::UpdateChannels {
3268        channels: vec![channel.to_proto()],
3269        ..Default::default()
3270    };
3271    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3272        if role.can_see_channel(channel.visibility) {
3273            session.peer.send(connection_id, update.clone())?;
3274        }
3275    }
3276
3277    Ok(())
3278}
3279
3280/// Move a channel to a new parent.
3281async fn move_channel(
3282    request: proto::MoveChannel,
3283    response: Response<proto::MoveChannel>,
3284    session: Session,
3285) -> Result<()> {
3286    let channel_id = ChannelId::from_proto(request.channel_id);
3287    let to = ChannelId::from_proto(request.to);
3288
3289    let (root_id, channels) = session
3290        .db()
3291        .await
3292        .move_channel(channel_id, to, session.user_id())
3293        .await?;
3294
3295    let connection_pool = session.connection_pool().await;
3296    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3297        let channels = channels
3298            .iter()
3299            .filter_map(|channel| {
3300                if role.can_see_channel(channel.visibility) {
3301                    Some(channel.to_proto())
3302                } else {
3303                    None
3304                }
3305            })
3306            .collect::<Vec<_>>();
3307        if channels.is_empty() {
3308            continue;
3309        }
3310
3311        let update = proto::UpdateChannels {
3312            channels,
3313            ..Default::default()
3314        };
3315
3316        session.peer.send(connection_id, update.clone())?;
3317    }
3318
3319    response.send(Ack {})?;
3320    Ok(())
3321}
3322
3323async fn reorder_channel(
3324    request: proto::ReorderChannel,
3325    response: Response<proto::ReorderChannel>,
3326    session: Session,
3327) -> Result<()> {
3328    let channel_id = ChannelId::from_proto(request.channel_id);
3329    let direction = request.direction();
3330
3331    let updated_channels = session
3332        .db()
3333        .await
3334        .reorder_channel(channel_id, direction, session.user_id())
3335        .await?;
3336
3337    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3338        let connection_pool = session.connection_pool().await;
3339        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3340            let channels = updated_channels
3341                .iter()
3342                .filter_map(|channel| {
3343                    if role.can_see_channel(channel.visibility) {
3344                        Some(channel.to_proto())
3345                    } else {
3346                        None
3347                    }
3348                })
3349                .collect::<Vec<_>>();
3350
3351            if channels.is_empty() {
3352                continue;
3353            }
3354
3355            let update = proto::UpdateChannels {
3356                channels,
3357                ..Default::default()
3358            };
3359
3360            session.peer.send(connection_id, update.clone())?;
3361        }
3362    }
3363
3364    response.send(Ack {})?;
3365    Ok(())
3366}
3367
3368/// Get the list of channel members
3369async fn get_channel_members(
3370    request: proto::GetChannelMembers,
3371    response: Response<proto::GetChannelMembers>,
3372    session: Session,
3373) -> Result<()> {
3374    let db = session.db().await;
3375    let channel_id = ChannelId::from_proto(request.channel_id);
3376    let limit = if request.limit == 0 {
3377        u16::MAX as u64
3378    } else {
3379        request.limit
3380    };
3381    let (members, users) = db
3382        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3383        .await?;
3384    response.send(proto::GetChannelMembersResponse { members, users })?;
3385    Ok(())
3386}
3387
3388/// Accept or decline a channel invitation.
3389async fn respond_to_channel_invite(
3390    request: proto::RespondToChannelInvite,
3391    response: Response<proto::RespondToChannelInvite>,
3392    session: Session,
3393) -> Result<()> {
3394    let db = session.db().await;
3395    let channel_id = ChannelId::from_proto(request.channel_id);
3396    let RespondToChannelInvite {
3397        membership_update,
3398        notifications,
3399    } = db
3400        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3401        .await?;
3402
3403    let mut connection_pool = session.connection_pool().await;
3404    if let Some(membership_update) = membership_update {
3405        notify_membership_updated(
3406            &mut connection_pool,
3407            membership_update,
3408            session.user_id(),
3409            &session.peer,
3410        );
3411    } else {
3412        let update = proto::UpdateChannels {
3413            remove_channel_invitations: vec![channel_id.to_proto()],
3414            ..Default::default()
3415        };
3416
3417        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3418            session.peer.send(connection_id, update.clone())?;
3419        }
3420    };
3421
3422    send_notifications(&connection_pool, &session.peer, notifications);
3423
3424    response.send(proto::Ack {})?;
3425
3426    Ok(())
3427}
3428
3429/// Join the channels' room
3430async fn join_channel(
3431    request: proto::JoinChannel,
3432    response: Response<proto::JoinChannel>,
3433    session: Session,
3434) -> Result<()> {
3435    let channel_id = ChannelId::from_proto(request.channel_id);
3436    join_channel_internal(channel_id, Box::new(response), session).await
3437}
3438
/// Abstraction over the response handles that can answer a channel join.
///
/// Both `JoinChannel` and `JoinRoom` requests reply with a
/// `proto::JoinRoomResponse`. This trait lets `join_channel_internal` send that
/// reply without knowing which of the two requests it is serving.
trait JoinChannelInternalResponse {
    /// Send the join result to the requester, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegate to the inherent `Response::send`. The `Type::method` path
        // resolves inherent associated functions before trait methods, so this
        // does not recurse into the trait impl.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegate to the inherent `Response::send` (see the impl above).
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3452
/// Shared join logic for requests answered with a `proto::JoinRoomResponse`
/// (see the `JoinChannelInternalResponse` impls).
///
/// Steps:
/// 1. Clean up any stale room connection left behind by this user.
/// 2. Join the channel's room in the database.
/// 3. Mint LiveKit credentials when a LiveKit client is configured.
/// 4. Reply to the requester.
/// 5. Broadcast membership, room, channel and contact updates.
///
/// Note that the reply is sent before `channel_updated` runs. A missing
/// channel ("channel not returned") is therefore reported as an error only
/// after the client has already received its `JoinRoomResponse`.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database handle while leaving the stale room and
            // re-acquire it afterwards. Presumably `leave_room_for_session`
            // needs its own handle.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Only produced when a LiveKit client is configured. Guests get a guest
        // token and cannot publish; every other role gets a room token with
        // publishing enabled. A token error is logged via `trace_err` and yields
        // no connection info rather than failing the join.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
        // The database handle and connection-pool guard taken in this block go
        // out of scope here. The pool is locked again below.
    };

    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3546
3547/// Start editing the channel notes
3548async fn join_channel_buffer(
3549    request: proto::JoinChannelBuffer,
3550    response: Response<proto::JoinChannelBuffer>,
3551    session: Session,
3552) -> Result<()> {
3553    let db = session.db().await;
3554    let channel_id = ChannelId::from_proto(request.channel_id);
3555
3556    let open_response = db
3557        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3558        .await?;
3559
3560    let collaborators = open_response.collaborators.clone();
3561    response.send(open_response)?;
3562
3563    let update = UpdateChannelBufferCollaborators {
3564        channel_id: channel_id.to_proto(),
3565        collaborators: collaborators.clone(),
3566    };
3567    channel_buffer_updated(
3568        session.connection_id,
3569        collaborators
3570            .iter()
3571            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3572        &update,
3573        &session.peer,
3574    );
3575
3576    Ok(())
3577}
3578
3579/// Edit the channel notes
3580async fn update_channel_buffer(
3581    request: proto::UpdateChannelBuffer,
3582    session: Session,
3583) -> Result<()> {
3584    let db = session.db().await;
3585    let channel_id = ChannelId::from_proto(request.channel_id);
3586
3587    let (collaborators, epoch, version) = db
3588        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3589        .await?;
3590
3591    channel_buffer_updated(
3592        session.connection_id,
3593        collaborators.clone(),
3594        &proto::UpdateChannelBuffer {
3595            channel_id: channel_id.to_proto(),
3596            operations: request.operations,
3597        },
3598        &session.peer,
3599    );
3600
3601    let pool = &*session.connection_pool().await;
3602
3603    let non_collaborators =
3604        pool.channel_connection_ids(channel_id)
3605            .filter_map(|(connection_id, _)| {
3606                if collaborators.contains(&connection_id) {
3607                    None
3608                } else {
3609                    Some(connection_id)
3610                }
3611            });
3612
3613    broadcast(None, non_collaborators, |peer_id| {
3614        session.peer.send(
3615            peer_id,
3616            proto::UpdateChannels {
3617                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3618                    channel_id: channel_id.to_proto(),
3619                    epoch: epoch as u64,
3620                    version: version.clone(),
3621                }],
3622                ..Default::default()
3623            },
3624        )
3625    });
3626
3627    Ok(())
3628}
3629
3630/// Rejoin the channel notes after a connection blip
3631async fn rejoin_channel_buffers(
3632    request: proto::RejoinChannelBuffers,
3633    response: Response<proto::RejoinChannelBuffers>,
3634    session: Session,
3635) -> Result<()> {
3636    let db = session.db().await;
3637    let buffers = db
3638        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3639        .await?;
3640
3641    for rejoined_buffer in &buffers {
3642        let collaborators_to_notify = rejoined_buffer
3643            .buffer
3644            .collaborators
3645            .iter()
3646            .filter_map(|c| Some(c.peer_id?.into()));
3647        channel_buffer_updated(
3648            session.connection_id,
3649            collaborators_to_notify,
3650            &proto::UpdateChannelBufferCollaborators {
3651                channel_id: rejoined_buffer.buffer.channel_id,
3652                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3653            },
3654            &session.peer,
3655        );
3656    }
3657
3658    response.send(proto::RejoinChannelBuffersResponse {
3659        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3660    })?;
3661
3662    Ok(())
3663}
3664
3665/// Stop editing the channel notes
3666async fn leave_channel_buffer(
3667    request: proto::LeaveChannelBuffer,
3668    response: Response<proto::LeaveChannelBuffer>,
3669    session: Session,
3670) -> Result<()> {
3671    let db = session.db().await;
3672    let channel_id = ChannelId::from_proto(request.channel_id);
3673
3674    let left_buffer = db
3675        .leave_channel_buffer(channel_id, session.connection_id)
3676        .await?;
3677
3678    response.send(Ack {})?;
3679
3680    channel_buffer_updated(
3681        session.connection_id,
3682        left_buffer.connections,
3683        &proto::UpdateChannelBufferCollaborators {
3684            channel_id: channel_id.to_proto(),
3685            collaborators: left_buffer.collaborators,
3686        },
3687        &session.peer,
3688    );
3689
3690    Ok(())
3691}
3692
3693fn channel_buffer_updated<T: EnvelopedMessage>(
3694    sender_id: ConnectionId,
3695    collaborators: impl IntoIterator<Item = ConnectionId>,
3696    message: &T,
3697    peer: &Peer,
3698) {
3699    broadcast(Some(sender_id), collaborators, |peer_id| {
3700        peer.send(peer_id, message.clone())
3701    });
3702}
3703
3704fn send_notifications(
3705    connection_pool: &ConnectionPool,
3706    peer: &Peer,
3707    notifications: db::NotificationBatch,
3708) {
3709    for (user_id, notification) in notifications {
3710        for connection_id in connection_pool.user_connection_ids(user_id) {
3711            if let Err(error) = peer.send(
3712                connection_id,
3713                proto::AddNotification {
3714                    notification: Some(notification.clone()),
3715                },
3716            ) {
3717                tracing::error!(
3718                    "failed to send notification to {:?} {}",
3719                    connection_id,
3720                    error
3721                );
3722            }
3723        }
3724    }
3725}
3726
3727/// Send a message to the channel
3728async fn send_channel_message(
3729    request: proto::SendChannelMessage,
3730    response: Response<proto::SendChannelMessage>,
3731    session: Session,
3732) -> Result<()> {
3733    // Validate the message body.
3734    let body = request.body.trim().to_string();
3735    if body.len() > MAX_MESSAGE_LEN {
3736        return Err(anyhow!("message is too long"))?;
3737    }
3738    if body.is_empty() {
3739        return Err(anyhow!("message can't be blank"))?;
3740    }
3741
3742    // TODO: adjust mentions if body is trimmed
3743
3744    let timestamp = OffsetDateTime::now_utc();
3745    let nonce = request.nonce.context("nonce can't be blank")?;
3746
3747    let channel_id = ChannelId::from_proto(request.channel_id);
3748    let CreatedChannelMessage {
3749        message_id,
3750        participant_connection_ids,
3751        notifications,
3752    } = session
3753        .db()
3754        .await
3755        .create_channel_message(
3756            channel_id,
3757            session.user_id(),
3758            &body,
3759            &request.mentions,
3760            timestamp,
3761            nonce.clone().into(),
3762            request.reply_to_message_id.map(MessageId::from_proto),
3763        )
3764        .await?;
3765
3766    let message = proto::ChannelMessage {
3767        sender_id: session.user_id().to_proto(),
3768        id: message_id.to_proto(),
3769        body,
3770        mentions: request.mentions,
3771        timestamp: timestamp.unix_timestamp() as u64,
3772        nonce: Some(nonce),
3773        reply_to_message_id: request.reply_to_message_id,
3774        edited_at: None,
3775    };
3776    broadcast(
3777        Some(session.connection_id),
3778        participant_connection_ids.clone(),
3779        |connection| {
3780            session.peer.send(
3781                connection,
3782                proto::ChannelMessageSent {
3783                    channel_id: channel_id.to_proto(),
3784                    message: Some(message.clone()),
3785                },
3786            )
3787        },
3788    );
3789    response.send(proto::SendChannelMessageResponse {
3790        message: Some(message),
3791    })?;
3792
3793    let pool = &*session.connection_pool().await;
3794    let non_participants =
3795        pool.channel_connection_ids(channel_id)
3796            .filter_map(|(connection_id, _)| {
3797                if participant_connection_ids.contains(&connection_id) {
3798                    None
3799                } else {
3800                    Some(connection_id)
3801                }
3802            });
3803    broadcast(None, non_participants, |peer_id| {
3804        session.peer.send(
3805            peer_id,
3806            proto::UpdateChannels {
3807                latest_channel_message_ids: vec![proto::ChannelMessageId {
3808                    channel_id: channel_id.to_proto(),
3809                    message_id: message_id.to_proto(),
3810                }],
3811                ..Default::default()
3812            },
3813        )
3814    });
3815    send_notifications(pool, &session.peer, notifications);
3816
3817    Ok(())
3818}
3819
3820/// Delete a channel message
3821async fn remove_channel_message(
3822    request: proto::RemoveChannelMessage,
3823    response: Response<proto::RemoveChannelMessage>,
3824    session: Session,
3825) -> Result<()> {
3826    let channel_id = ChannelId::from_proto(request.channel_id);
3827    let message_id = MessageId::from_proto(request.message_id);
3828    let (connection_ids, existing_notification_ids) = session
3829        .db()
3830        .await
3831        .remove_channel_message(channel_id, message_id, session.user_id())
3832        .await?;
3833
3834    broadcast(
3835        Some(session.connection_id),
3836        connection_ids,
3837        move |connection| {
3838            session.peer.send(connection, request.clone())?;
3839
3840            for notification_id in &existing_notification_ids {
3841                session.peer.send(
3842                    connection,
3843                    proto::DeleteNotification {
3844                        notification_id: (*notification_id).to_proto(),
3845                    },
3846                )?;
3847            }
3848
3849            Ok(())
3850        },
3851    );
3852    response.send(proto::Ack {})?;
3853    Ok(())
3854}
3855
3856async fn update_channel_message(
3857    request: proto::UpdateChannelMessage,
3858    response: Response<proto::UpdateChannelMessage>,
3859    session: Session,
3860) -> Result<()> {
3861    let channel_id = ChannelId::from_proto(request.channel_id);
3862    let message_id = MessageId::from_proto(request.message_id);
3863    let updated_at = OffsetDateTime::now_utc();
3864    let UpdatedChannelMessage {
3865        message_id,
3866        participant_connection_ids,
3867        notifications,
3868        reply_to_message_id,
3869        timestamp,
3870        deleted_mention_notification_ids,
3871        updated_mention_notifications,
3872    } = session
3873        .db()
3874        .await
3875        .update_channel_message(
3876            channel_id,
3877            message_id,
3878            session.user_id(),
3879            request.body.as_str(),
3880            &request.mentions,
3881            updated_at,
3882        )
3883        .await?;
3884
3885    let nonce = request.nonce.clone().context("nonce can't be blank")?;
3886
3887    let message = proto::ChannelMessage {
3888        sender_id: session.user_id().to_proto(),
3889        id: message_id.to_proto(),
3890        body: request.body.clone(),
3891        mentions: request.mentions.clone(),
3892        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3893        nonce: Some(nonce),
3894        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3895        edited_at: Some(updated_at.unix_timestamp() as u64),
3896    };
3897
3898    response.send(proto::Ack {})?;
3899
3900    let pool = &*session.connection_pool().await;
3901    broadcast(
3902        Some(session.connection_id),
3903        participant_connection_ids,
3904        |connection| {
3905            session.peer.send(
3906                connection,
3907                proto::ChannelMessageUpdate {
3908                    channel_id: channel_id.to_proto(),
3909                    message: Some(message.clone()),
3910                },
3911            )?;
3912
3913            for notification_id in &deleted_mention_notification_ids {
3914                session.peer.send(
3915                    connection,
3916                    proto::DeleteNotification {
3917                        notification_id: (*notification_id).to_proto(),
3918                    },
3919                )?;
3920            }
3921
3922            for notification in &updated_mention_notifications {
3923                session.peer.send(
3924                    connection,
3925                    proto::UpdateNotification {
3926                        notification: Some(notification.clone()),
3927                    },
3928                )?;
3929            }
3930
3931            Ok(())
3932        },
3933    );
3934
3935    send_notifications(pool, &session.peer, notifications);
3936
3937    Ok(())
3938}
3939
3940/// Mark a channel message as read
3941async fn acknowledge_channel_message(
3942    request: proto::AckChannelMessage,
3943    session: Session,
3944) -> Result<()> {
3945    let channel_id = ChannelId::from_proto(request.channel_id);
3946    let message_id = MessageId::from_proto(request.message_id);
3947    let notifications = session
3948        .db()
3949        .await
3950        .observe_channel_message(channel_id, session.user_id(), message_id)
3951        .await?;
3952    send_notifications(
3953        &*session.connection_pool().await,
3954        &session.peer,
3955        notifications,
3956    );
3957    Ok(())
3958}
3959
3960/// Mark a buffer version as synced
3961async fn acknowledge_buffer_version(
3962    request: proto::AckBufferOperation,
3963    session: Session,
3964) -> Result<()> {
3965    let buffer_id = BufferId::from_proto(request.buffer_id);
3966    session
3967        .db()
3968        .await
3969        .observe_buffer_version(
3970            buffer_id,
3971            session.user_id(),
3972            request.epoch as i32,
3973            &request.version,
3974        )
3975        .await?;
3976    Ok(())
3977}
3978
3979/// Get a Supermaven API key for the user
3980async fn get_supermaven_api_key(
3981    _request: proto::GetSupermavenApiKey,
3982    response: Response<proto::GetSupermavenApiKey>,
3983    session: Session,
3984) -> Result<()> {
3985    let user_id: String = session.user_id().to_string();
3986    if !session.is_staff() {
3987        return Err(anyhow!("supermaven not enabled for this account"))?;
3988    }
3989
3990    let email = session.email().context("user must have an email")?;
3991
3992    let supermaven_admin_api = session
3993        .supermaven_client
3994        .as_ref()
3995        .context("supermaven not configured")?;
3996
3997    let result = supermaven_admin_api
3998        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3999        .await?;
4000
4001    response.send(proto::GetSupermavenApiKeyResponse {
4002        api_key: result.api_key,
4003    })?;
4004
4005    Ok(())
4006}
4007
4008/// Start receiving chat updates for a channel
4009async fn join_channel_chat(
4010    request: proto::JoinChannelChat,
4011    response: Response<proto::JoinChannelChat>,
4012    session: Session,
4013) -> Result<()> {
4014    let channel_id = ChannelId::from_proto(request.channel_id);
4015
4016    let db = session.db().await;
4017    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4018        .await?;
4019    let messages = db
4020        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4021        .await?;
4022    response.send(proto::JoinChannelChatResponse {
4023        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4024        messages,
4025    })?;
4026    Ok(())
4027}
4028
4029/// Stop receiving chat updates for a channel
4030async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
4031    let channel_id = ChannelId::from_proto(request.channel_id);
4032    session
4033        .db()
4034        .await
4035        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4036        .await?;
4037    Ok(())
4038}
4039
4040/// Retrieve the chat history for a channel
4041async fn get_channel_messages(
4042    request: proto::GetChannelMessages,
4043    response: Response<proto::GetChannelMessages>,
4044    session: Session,
4045) -> Result<()> {
4046    let channel_id = ChannelId::from_proto(request.channel_id);
4047    let messages = session
4048        .db()
4049        .await
4050        .get_channel_messages(
4051            channel_id,
4052            session.user_id(),
4053            MESSAGE_COUNT_PER_PAGE,
4054            Some(MessageId::from_proto(request.before_message_id)),
4055        )
4056        .await?;
4057    response.send(proto::GetChannelMessagesResponse {
4058        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4059        messages,
4060    })?;
4061    Ok(())
4062}
4063
4064/// Retrieve specific chat messages
4065async fn get_channel_messages_by_id(
4066    request: proto::GetChannelMessagesById,
4067    response: Response<proto::GetChannelMessagesById>,
4068    session: Session,
4069) -> Result<()> {
4070    let message_ids = request
4071        .message_ids
4072        .iter()
4073        .map(|id| MessageId::from_proto(*id))
4074        .collect::<Vec<_>>();
4075    let messages = session
4076        .db()
4077        .await
4078        .get_channel_messages_by_id(session.user_id(), &message_ids)
4079        .await?;
4080    response.send(proto::GetChannelMessagesResponse {
4081        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4082        messages,
4083    })?;
4084    Ok(())
4085}
4086
4087/// Retrieve the current users notifications
4088async fn get_notifications(
4089    request: proto::GetNotifications,
4090    response: Response<proto::GetNotifications>,
4091    session: Session,
4092) -> Result<()> {
4093    let notifications = session
4094        .db()
4095        .await
4096        .get_notifications(
4097            session.user_id(),
4098            NOTIFICATION_COUNT_PER_PAGE,
4099            request.before_id.map(db::NotificationId::from_proto),
4100        )
4101        .await?;
4102    response.send(proto::GetNotificationsResponse {
4103        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4104        notifications,
4105    })?;
4106    Ok(())
4107}
4108
4109/// Mark notifications as read
4110async fn mark_notification_as_read(
4111    request: proto::MarkNotificationRead,
4112    response: Response<proto::MarkNotificationRead>,
4113    session: Session,
4114) -> Result<()> {
4115    let database = &session.db().await;
4116    let notifications = database
4117        .mark_notification_as_read_by_id(
4118            session.user_id(),
4119            NotificationId::from_proto(request.notification_id),
4120        )
4121        .await?;
4122    send_notifications(
4123        &*session.connection_pool().await,
4124        &session.peer,
4125        notifications,
4126    );
4127    response.send(proto::Ack {})?;
4128    Ok(())
4129}
4130
4131/// Get the current users information
4132async fn get_private_user_info(
4133    _request: proto::GetPrivateUserInfo,
4134    response: Response<proto::GetPrivateUserInfo>,
4135    session: Session,
4136) -> Result<()> {
4137    let db = session.db().await;
4138
4139    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4140    let user = db
4141        .get_user_by_id(session.user_id())
4142        .await?
4143        .context("user not found")?;
4144    let flags = db.get_user_flags(session.user_id()).await?;
4145
4146    response.send(proto::GetPrivateUserInfoResponse {
4147        metrics_id,
4148        staff: user.admin,
4149        flags,
4150        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4151    })?;
4152    Ok(())
4153}
4154
4155/// Accept the terms of service (tos) on behalf of the current user
4156async fn accept_terms_of_service(
4157    _request: proto::AcceptTermsOfService,
4158    response: Response<proto::AcceptTermsOfService>,
4159    session: Session,
4160) -> Result<()> {
4161    let db = session.db().await;
4162
4163    let accepted_tos_at = Utc::now();
4164    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4165        .await?;
4166
4167    response.send(proto::AcceptTermsOfServiceResponse {
4168        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4169    })?;
4170
4171    // When the user accepts the terms of service, we want to refresh their LLM
4172    // token to grant access.
4173    session
4174        .peer
4175        .send(session.connection_id, proto::RefreshLlmToken {})?;
4176
4177    Ok(())
4178}
4179
/// Issue an LLM API token for the current user.
///
/// Preconditions:
/// - The user must exist and have accepted the terms of service.
/// - Both the Stripe client and Stripe billing must be configured.
///
/// Setup, performed on demand when missing:
/// - A billing customer. The Stripe customer is found or created by the
///   user's email, then the local billing customer record is found or created
///   for it.
/// - An active billing subscription. The customer is subscribed to Zed Free
///   in Stripe and the subscription is recorded locally.
///
/// The token claims are then built from the user, staff status, billing
/// state, feature flags, system id and app config.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .with_context(|| format!("user {user_id} not found"))?;

    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    let stripe_client = session
        .app_state
        .stripe_client
        .as_ref()
        .context("failed to retrieve Stripe client")?;

    let stripe_billing = session
        .app_state
        .stripe_billing
        .as_ref()
        .context("failed to retrieve Stripe billing object")?;

    // Reuse the user's billing customer if one exists; otherwise resolve the
    // Stripe customer by email and find or create the matching local record.
    let billing_customer = if let Some(billing_customer) =
        db.get_billing_customer_by_user_id(user.id).await?
    {
        billing_customer
    } else {
        let customer_id = stripe_billing
            .find_or_create_customer_by_email(user.email_address.as_deref())
            .await?;

        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
            .await?
            .context("billing customer not found")?
    };

    // Users without an active subscription are subscribed to Zed Free in
    // Stripe, and that subscription is persisted before the token is issued.
    let billing_subscription =
        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
            billing_subscription
        } else {
            let stripe_customer_id =
                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());

            let stripe_subscription = stripe_billing
                .subscribe_to_zed_free(stripe_customer_id)
                .await?;

            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
                billing_customer_id: billing_customer.id,
                kind: Some(SubscriptionKind::ZedFree),
                stripe_subscription_id: stripe_subscription.id.to_string(),
                stripe_subscription_status: stripe_subscription.status.into(),
                stripe_cancellation_reason: None,
                stripe_current_period_start: Some(stripe_subscription.current_period_start),
                stripe_current_period_end: Some(stripe_subscription.current_period_end),
            })
            .await?
        };

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_customer,
        billing_preferences,
        &flags,
        billing_subscription,
        session.system_id.clone(),
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
4263
4264fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4265    let message = match message {
4266        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4267        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4268        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4269        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4270        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4271            code: frame.code.into(),
4272            reason: frame.reason.as_str().to_owned().into(),
4273        })),
4274        // We should never receive a frame while reading the message, according
4275        // to the `tungstenite` maintainers:
4276        //
4277        // > It cannot occur when you read messages from the WebSocket, but it
4278        // > can be used when you want to send the raw frames (e.g. you want to
4279        // > send the frames to the WebSocket without composing the full message first).
4280        // >
4281        // > — https://github.com/snapview/tungstenite-rs/issues/268
4282        TungsteniteMessage::Frame(_) => {
4283            bail!("received an unexpected frame while reading the message")
4284        }
4285    };
4286
4287    Ok(message)
4288}
4289
4290fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4291    match message {
4292        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4293        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4294        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4295        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4296        AxumMessage::Close(frame) => {
4297            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4298                code: frame.code.into(),
4299                reason: frame.reason.as_ref().into(),
4300            }))
4301        }
4302    }
4303}
4304
4305fn notify_membership_updated(
4306    connection_pool: &mut ConnectionPool,
4307    result: MembershipUpdated,
4308    user_id: UserId,
4309    peer: &Peer,
4310) {
4311    for membership in &result.new_channels.channel_memberships {
4312        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4313    }
4314    for channel_id in &result.removed_channels {
4315        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4316    }
4317
4318    let user_channels_update = proto::UpdateUserChannels {
4319        channel_memberships: result
4320            .new_channels
4321            .channel_memberships
4322            .iter()
4323            .map(|cm| proto::ChannelMembership {
4324                channel_id: cm.channel_id.to_proto(),
4325                role: cm.role.into(),
4326            })
4327            .collect(),
4328        ..Default::default()
4329    };
4330
4331    let mut update = build_channels_update(result.new_channels);
4332    update.delete_channels = result
4333        .removed_channels
4334        .into_iter()
4335        .map(|id| id.to_proto())
4336        .collect();
4337    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4338
4339    for connection_id in connection_pool.user_connection_ids(user_id) {
4340        peer.send(connection_id, user_channels_update.clone())
4341            .trace_err();
4342        peer.send(connection_id, update.clone()).trace_err();
4343    }
4344}
4345
4346fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4347    proto::UpdateUserChannels {
4348        channel_memberships: channels
4349            .channel_memberships
4350            .iter()
4351            .map(|m| proto::ChannelMembership {
4352                channel_id: m.channel_id.to_proto(),
4353                role: m.role.into(),
4354            })
4355            .collect(),
4356        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4357        observed_channel_message_id: channels.observed_channel_messages.clone(),
4358    }
4359}
4360
4361fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4362    let mut update = proto::UpdateChannels::default();
4363
4364    for channel in channels.channels {
4365        update.channels.push(channel.to_proto());
4366    }
4367
4368    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4369    update.latest_channel_message_ids = channels.latest_channel_messages;
4370
4371    for (channel_id, participants) in channels.channel_participants {
4372        update
4373            .channel_participants
4374            .push(proto::ChannelParticipants {
4375                channel_id: channel_id.to_proto(),
4376                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4377            });
4378    }
4379
4380    for channel in channels.invited_channels {
4381        update.channel_invitations.push(channel.to_proto());
4382    }
4383
4384    update
4385}
4386
4387fn build_initial_contacts_update(
4388    contacts: Vec<db::Contact>,
4389    pool: &ConnectionPool,
4390) -> proto::UpdateContacts {
4391    let mut update = proto::UpdateContacts::default();
4392
4393    for contact in contacts {
4394        match contact {
4395            db::Contact::Accepted { user_id, busy } => {
4396                update.contacts.push(contact_for_user(user_id, busy, pool));
4397            }
4398            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4399            db::Contact::Incoming { user_id } => {
4400                update
4401                    .incoming_requests
4402                    .push(proto::IncomingContactRequest {
4403                        requester_id: user_id.to_proto(),
4404                    })
4405            }
4406        }
4407    }
4408
4409    update
4410}
4411
4412fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4413    proto::Contact {
4414        user_id: user_id.to_proto(),
4415        online: pool.is_user_online(user_id),
4416        busy,
4417    }
4418}
4419
4420fn room_updated(room: &proto::Room, peer: &Peer) {
4421    broadcast(
4422        None,
4423        room.participants
4424            .iter()
4425            .filter_map(|participant| Some(participant.peer_id?.into())),
4426        |peer_id| {
4427            peer.send(
4428                peer_id,
4429                proto::RoomUpdated {
4430                    room: Some(room.clone()),
4431                },
4432            )
4433        },
4434    );
4435}
4436
4437fn channel_updated(
4438    channel: &db::channel::Model,
4439    room: &proto::Room,
4440    peer: &Peer,
4441    pool: &ConnectionPool,
4442) {
4443    let participants = room
4444        .participants
4445        .iter()
4446        .map(|p| p.user_id)
4447        .collect::<Vec<_>>();
4448
4449    broadcast(
4450        None,
4451        pool.channel_connection_ids(channel.root_id())
4452            .filter_map(|(channel_id, role)| {
4453                role.can_see_channel(channel.visibility)
4454                    .then_some(channel_id)
4455            }),
4456        |peer_id| {
4457            peer.send(
4458                peer_id,
4459                proto::UpdateChannels {
4460                    channel_participants: vec![proto::ChannelParticipants {
4461                        channel_id: channel.id.to_proto(),
4462                        participant_user_ids: participants.clone(),
4463                    }],
4464                    ..Default::default()
4465                },
4466            )
4467        },
4468    );
4469}
4470
4471async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4472    let db = session.db().await;
4473
4474    let contacts = db.get_contacts(user_id).await?;
4475    let busy = db.is_user_busy(user_id).await?;
4476
4477    let pool = session.connection_pool().await;
4478    let updated_contact = contact_for_user(user_id, busy, &pool);
4479    for contact in contacts {
4480        if let db::Contact::Accepted {
4481            user_id: contact_user_id,
4482            ..
4483        } = contact
4484        {
4485            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4486                session
4487                    .peer
4488                    .send(
4489                        contact_conn_id,
4490                        proto::UpdateContacts {
4491                            contacts: vec![updated_contact.clone()],
4492                            remove_contacts: Default::default(),
4493                            incoming_requests: Default::default(),
4494                            remove_incoming_requests: Default::default(),
4495                            outgoing_requests: Default::default(),
4496                            remove_outgoing_requests: Default::default(),
4497                        },
4498                    )
4499                    .trace_err();
4500            }
4501        }
4502    }
4503    Ok(())
4504}
4505
4506async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4507    let mut contacts_to_update = HashSet::default();
4508
4509    let room_id;
4510    let canceled_calls_to_user_ids;
4511    let livekit_room;
4512    let delete_livekit_room;
4513    let room;
4514    let channel;
4515
4516    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4517        contacts_to_update.insert(session.user_id());
4518
4519        for project in left_room.left_projects.values() {
4520            project_left(project, session);
4521        }
4522
4523        room_id = RoomId::from_proto(left_room.room.id);
4524        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4525        livekit_room = mem::take(&mut left_room.room.livekit_room);
4526        delete_livekit_room = left_room.deleted;
4527        room = mem::take(&mut left_room.room);
4528        channel = mem::take(&mut left_room.channel);
4529
4530        room_updated(&room, &session.peer);
4531    } else {
4532        return Ok(());
4533    }
4534
4535    if let Some(channel) = channel {
4536        channel_updated(
4537            &channel,
4538            &room,
4539            &session.peer,
4540            &*session.connection_pool().await,
4541        );
4542    }
4543
4544    {
4545        let pool = session.connection_pool().await;
4546        for canceled_user_id in canceled_calls_to_user_ids {
4547            for connection_id in pool.user_connection_ids(canceled_user_id) {
4548                session
4549                    .peer
4550                    .send(
4551                        connection_id,
4552                        proto::CallCanceled {
4553                            room_id: room_id.to_proto(),
4554                        },
4555                    )
4556                    .trace_err();
4557            }
4558            contacts_to_update.insert(canceled_user_id);
4559        }
4560    }
4561
4562    for contact_user_id in contacts_to_update {
4563        update_user_contacts(contact_user_id, session).await?;
4564    }
4565
4566    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4567        live_kit
4568            .remove_participant(livekit_room.clone(), session.user_id().to_string())
4569            .await
4570            .trace_err();
4571
4572        if delete_livekit_room {
4573            live_kit.delete_room(livekit_room).await.trace_err();
4574        }
4575    }
4576
4577    Ok(())
4578}
4579
4580async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4581    let left_channel_buffers = session
4582        .db()
4583        .await
4584        .leave_channel_buffers(session.connection_id)
4585        .await?;
4586
4587    for left_buffer in left_channel_buffers {
4588        channel_buffer_updated(
4589            session.connection_id,
4590            left_buffer.connections,
4591            &proto::UpdateChannelBufferCollaborators {
4592                channel_id: left_buffer.channel_id.to_proto(),
4593                collaborators: left_buffer.collaborators,
4594            },
4595            &session.peer,
4596        );
4597    }
4598
4599    Ok(())
4600}
4601
4602fn project_left(project: &db::LeftProject, session: &Session) {
4603    for connection_id in &project.connection_ids {
4604        if project.should_unshare {
4605            session
4606                .peer
4607                .send(
4608                    *connection_id,
4609                    proto::UnshareProject {
4610                        project_id: project.id.to_proto(),
4611                    },
4612                )
4613                .trace_err();
4614        } else {
4615            session
4616                .peer
4617                .send(
4618                    *connection_id,
4619                    proto::RemoveProjectCollaborator {
4620                        project_id: project.id.to_proto(),
4621                        peer_id: Some(session.connection_id.into()),
4622                    },
4623                )
4624                .trace_err();
4625        }
4626    }
4627}
4628
4629pub trait ResultExt {
4630    type Ok;
4631
4632    fn trace_err(self) -> Option<Self::Ok>;
4633}
4634
4635impl<T, E> ResultExt for Result<T, E>
4636where
4637    E: std::fmt::Debug,
4638{
4639    type Ok = T;
4640
4641    #[track_caller]
4642    fn trace_err(self) -> Option<T> {
4643        match self {
4644            Ok(value) => Some(value),
4645            Err(error) => {
4646                tracing::error!("{:?}", error);
4647                None
4648            }
4649        }
4650    }
4651}