rpc.rs

   1mod connection_pool;
   2
   3use crate::api::billing::find_or_create_billing_customer;
   4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   5use crate::db::billing_subscription::SubscriptionKind;
   6use crate::llm::db::LlmDatabase;
   7use crate::llm::{
   8    AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
   9    MIN_ACCOUNT_AGE_FOR_LLM_USE,
  10};
  11use crate::stripe_client::StripeCustomerId;
  12use crate::{
  13    AppState, Error, Result, auth,
  14    db::{
  15        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
  16        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  17        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  18        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  19    },
  20    executor::Executor,
  21};
  22use anyhow::{Context as _, anyhow, bail};
  23use async_tungstenite::tungstenite::{
  24    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  25};
  26use axum::{
  27    Extension, Router, TypedHeader,
  28    body::Body,
  29    extract::{
  30        ConnectInfo, WebSocketUpgrade,
  31        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  32    },
  33    headers::{Header, HeaderName},
  34    http::StatusCode,
  35    middleware,
  36    response::IntoResponse,
  37    routing::get,
  38};
  39use chrono::Utc;
  40use collections::{HashMap, HashSet};
  41pub use connection_pool::{ConnectionPool, ZedVersion};
  42use core::fmt::{self, Debug, Formatter};
  43use reqwest_client::ReqwestClient;
  44use rpc::proto::split_repository_update;
  45use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  46
  47use futures::{
  48    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  49    stream::FuturesUnordered,
  50};
  51use prometheus::{IntGauge, register_int_gauge};
  52use rpc::{
  53    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  54    proto::{
  55        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  56        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  57    },
  58};
  59use semantic_version::SemanticVersion;
  60use serde::{Serialize, Serializer};
  61use std::{
  62    any::TypeId,
  63    future::Future,
  64    marker::PhantomData,
  65    mem,
  66    net::SocketAddr,
  67    ops::{Deref, DerefMut},
  68    rc::Rc,
  69    sync::{
  70        Arc, OnceLock,
  71        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  72    },
  73    time::{Duration, Instant},
  74};
  75use time::OffsetDateTime;
  76use tokio::sync::{Semaphore, watch};
  77use tower::ServiceBuilder;
  78use tracing::{
  79    Instrument,
  80    field::{self},
  81    info_span, instrument,
  82};
  83
/// Reconnection grace period; presumably how long a disconnected client may take to
/// reconnect before its state is released (its uses are not visible here — TODO confirm).
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before the background task spawned by `Server::start` begins cleaning up
/// stale server resources.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we
/// can clean up old resources, so this waits somewhat longer than that grace period.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the next three limits are not used in this part of the file; their
// descriptions are inferred from their names — confirm against the call sites.
/// Page size, presumably for channel-message queries.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum allowed message length (unit — bytes or chars — not visible here).
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size, presumably for notification queries.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Upper bound on simultaneously held `ConnectionGuard`s.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Number of `ConnectionGuard`s currently held, transiently including acquisition
/// attempts that are about to be rolled back.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

/// Type-erased handler stored in `Server::handlers`: given a boxed envelope and the
/// sender's `Session`, it returns the future that processes the message.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  98
  99pub struct ConnectionGuard;
 100
 101impl ConnectionGuard {
 102    pub fn try_acquire() -> Result<Self, ()> {
 103        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
 104        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 105            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 106            tracing::error!(
 107                "too many concurrent connections: {}",
 108                current_connections + 1
 109            );
 110            return Err(());
 111        }
 112        Ok(ConnectionGuard)
 113    }
 114}
 115
 116impl Drop for ConnectionGuard {
 117    fn drop(&mut self) {
 118        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 119    }
 120}
 121
 122struct Response<R> {
 123    peer: Arc<Peer>,
 124    receipt: Receipt<R>,
 125    responded: Arc<AtomicBool>,
 126}
 127
 128impl<R: RequestMessage> Response<R> {
 129    fn send(self, payload: R::Response) -> Result<()> {
 130        self.responded.store(true, SeqCst);
 131        self.peer.respond(self.receipt, payload)?;
 132        Ok(())
 133    }
 134}
 135
 136#[derive(Clone, Debug)]
 137pub enum Principal {
 138    User(User),
 139    Impersonated { user: User, admin: User },
 140}
 141
 142impl Principal {
 143    fn user(&self) -> &User {
 144        match self {
 145            Principal::User(user) => user,
 146            Principal::Impersonated { user, .. } => user,
 147        }
 148    }
 149
 150    fn update_span(&self, span: &tracing::Span) {
 151        match &self {
 152            Principal::User(user) => {
 153                span.record("user_id", user.id.0);
 154                span.record("login", &user.github_login);
 155            }
 156            Principal::Impersonated { user, admin } => {
 157                span.record("user_id", user.id.0);
 158                span.record("login", &user.github_login);
 159                span.record("impersonator", &admin.github_login);
 160            }
 161        }
 162    }
 163}
 164
/// Per-connection state passed by value to every message handler.
#[derive(Clone)]
struct Session {
    /// Identity this connection acts under (possibly an admin impersonating a user).
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it with `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Connection pool behind a synchronous `parking_lot` mutex.
    ///
    /// Lock it through `Session::connection_pool`, which returns a `!Send` guard.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// `None` presumably when no Supermaven admin client is configured — confirm at construction.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    // Kept on the session even though it may not be read anywhere yet, hence `allow(unused)`.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    // Held but not read directly (leading underscore).
    _executor: Executor,
}
 180
 181impl Session {
 182    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 183        #[cfg(test)]
 184        tokio::task::yield_now().await;
 185        let guard = self.db.lock().await;
 186        #[cfg(test)]
 187        tokio::task::yield_now().await;
 188        guard
 189    }
 190
 191    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 192        #[cfg(test)]
 193        tokio::task::yield_now().await;
 194        let guard = self.connection_pool.lock();
 195        ConnectionPoolGuard {
 196            guard,
 197            _not_send: PhantomData,
 198        }
 199    }
 200
 201    fn is_staff(&self) -> bool {
 202        match &self.principal {
 203            Principal::User(user) => user.admin,
 204            Principal::Impersonated { .. } => true,
 205        }
 206    }
 207
 208    fn user_id(&self) -> UserId {
 209        match &self.principal {
 210            Principal::User(user) => user.id,
 211            Principal::Impersonated { user, .. } => user.id,
 212        }
 213    }
 214
 215    pub fn email(&self) -> Option<String> {
 216        match &self.principal {
 217            Principal::User(user) => user.email_address.clone(),
 218            Principal::Impersonated { user, .. } => user.email_address.clone(),
 219        }
 220    }
 221}
 222
 223impl Debug for Session {
 224    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 225        let mut result = f.debug_struct("Session");
 226        match &self.principal {
 227            Principal::User(user) => {
 228                result.field("user", &user.github_login);
 229            }
 230            Principal::Impersonated { user, admin } => {
 231                result.field("user", &user.github_login);
 232                result.field("impersonator", &admin.github_login);
 233            }
 234        }
 235        result.field("connection_id", &self.connection_id).finish()
 236    }
 237}
 238
 239struct DbHandle(Arc<Database>);
 240
 241impl Deref for DbHandle {
 242    type Target = Database;
 243
 244    fn deref(&self) -> &Self::Target {
 245        self.0.as_ref()
 246    }
 247}
 248
/// The collaboration RPC server.
///
/// It owns the peer, the connection pool, and the message-handler table populated in
/// `Server::new`.
pub struct Server {
    /// Behind a mutex so it can be replaced; in this file only the test-only
    /// `Server::reset` does so.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message payload type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag: `Server::teardown` sets it to `true`, `Server::reset` back to `false`.
    teardown: watch::Sender<bool>,
}
 257
/// A locked [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes this guard `!Send`. A future that must be
/// `Send` therefore cannot hold it across an `.await`, so the synchronous `parking_lot`
/// lock is never kept while such a task is suspended.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 262
/// Serializable view of the server's peer and connection pool.
///
/// The pool is serialized through [`serialize_deref`], i.e. as the `ConnectionPool`
/// behind the guard. A snapshot keeps the pool locked for as long as it is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 269
 270pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 271where
 272    S: Serializer,
 273    T: Deref<Target = U>,
 274    U: Serialize,
 275{
 276    Serialize::serialize(value.deref(), serializer)
 277}
 278
 279impl Server {
    /// Creates a server with the given id and registers every RPC handler it serves.
    ///
    /// Each message type may be registered at most once. `add_handler` panics on a
    /// duplicate, so a repeated entry in the chain below fails at construction time.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): unchecked `as` cast — confirm ServerId values always fit in u32.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; the initial receiver is dropped immediately.
            teardown: watch::channel(false).0,
        };

        // Project requests are forwarded to the host through either
        // `forward_read_only_project_request` or `forward_mutating_project_request`.
        // Presumably the latter additionally requires write access — verify against their
        // definitions.
        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_find_search_candidates_request)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LanguageServerIdForName>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_request_handler(get_llm_api_token)
            .add_request_handler(accept_terms_of_service)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            // NOTE(review): GitReset and GitCheckoutFiles look like operations that change
            // repository state. Yet they use the read-only forwarder, unlike Stage/Unstage/
            // Commit above — confirm this permission level is intended.
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context);

        Arc::new(server)
    }
 455
    /// Spawns the post-startup cleanup task and returns immediately.
    ///
    /// The detached task waits [`CLEANUP_TIMEOUT`]. It then runs the database's stale-resource
    /// queries for this server's id and `zed_environment`. For the stale channel buffers and
    /// rooms those queries return, it pushes refreshed collaborator, participant, and contact
    /// state to the affected connections. It deletes LiveKit rooms left without participants,
    /// and finally calls `clear_old_worktree_entries` and `delete_stale_servers`.
    ///
    /// Every failure inside the task is logged via `trace_err` and skipped, so this function
    /// itself always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        // Copy/clone what the detached task needs so its future does not borrow `self`.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created here but only awaited inside the task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                // NOTE(review): this exact call is repeated after the room loop below —
                // confirm whether both invocations are needed.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Clear stale collaborators from each channel buffer, then send the
                    // refreshed collaborator list to the connections the DB call returns.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These defaults apply if clearing the room fails. In that case no
                        // canceled calls are announced and the LiveKit room is not deleted.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Users whose contact entries are refreshed below: the stale
                            // participants and the recipients of canceled calls.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // Only delete the LiveKit room once nobody is left in it.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scope the synchronous pool lock to this block so the guard is
                        // released before the `.await`s that follow.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to every
                        // connection of each of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // NOTE(review): repeats the call made at the top of this task.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 626
 627    pub fn teardown(&self) {
 628        self.peer.teardown();
 629        self.connection_pool.lock().reset();
 630        let _ = self.teardown.send(true);
 631    }
 632
 633    #[cfg(test)]
 634    pub fn reset(&self, id: ServerId) {
 635        self.teardown();
 636        *self.id.lock() = id;
 637        self.peer.reset(id.0 as u32);
 638        let _ = self.teardown.send(false);
 639    }
 640
 641    #[cfg(test)]
 642    pub fn id(&self) -> ServerId {
 643        *self.id.lock()
 644    }
 645
 646    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 647    where
 648        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 649        Fut: 'static + Send + Future<Output = Result<()>>,
 650        M: EnvelopedMessage,
 651    {
 652        let prev_handler = self.handlers.insert(
 653            TypeId::of::<M>(),
 654            Box::new(move |envelope, session| {
 655                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 656                let received_at = envelope.received_at;
 657                tracing::info!("message received");
 658                let start_time = Instant::now();
 659                let future = (handler)(*envelope, session);
 660                async move {
 661                    let result = future.await;
 662                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 663                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 664                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 665                    let payload_type = M::NAME;
 666
 667                    match result {
 668                        Err(error) => {
 669                            tracing::error!(
 670                                ?error,
 671                                total_duration_ms,
 672                                processing_duration_ms,
 673                                queue_duration_ms,
 674                                payload_type,
 675                                "error handling message"
 676                            )
 677                        }
 678                        Ok(()) => tracing::info!(
 679                            total_duration_ms,
 680                            processing_duration_ms,
 681                            queue_duration_ms,
 682                            "finished handling message"
 683                        ),
 684                    }
 685                }
 686                .boxed()
 687            }),
 688        );
 689        if prev_handler.is_some() {
 690            panic!("registered a handler for the same message twice");
 691        }
 692        self
 693    }
 694
 695    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 696    where
 697        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 698        Fut: 'static + Send + Future<Output = Result<()>>,
 699        M: EnvelopedMessage,
 700    {
 701        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 702        self
 703    }
 704
 705    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 706    where
 707        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 708        Fut: Send + Future<Output = Result<()>>,
 709        M: RequestMessage,
 710    {
 711        let handler = Arc::new(handler);
 712        self.add_handler(move |envelope, session| {
 713            let receipt = envelope.receipt();
 714            let handler = handler.clone();
 715            async move {
 716                let peer = session.peer.clone();
 717                let responded = Arc::new(AtomicBool::default());
 718                let response = Response {
 719                    peer: peer.clone(),
 720                    responded: responded.clone(),
 721                    receipt,
 722                };
 723                match (handler)(envelope.payload, response, session).await {
 724                    Ok(()) => {
 725                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 726                            Ok(())
 727                        } else {
 728                            Err(anyhow!("handler did not send a response"))?
 729                        }
 730                    }
 731                    Err(error) => {
 732                        let proto_err = match &error {
 733                            Error::Internal(err) => err.to_proto(),
 734                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 735                        };
 736                        peer.respond_with_error(receipt, proto_err)?;
 737                        Err(error)
 738                    }
 739                }
 740            }
 741        })
 742    }
 743
    /// Serves one client connection for its whole lifetime.
    ///
    /// Returns a future, instrumented with a "handle connection" span, that:
    /// - bails out immediately if the teardown signal is already set;
    /// - registers the connection with the peer and builds the `Session`;
    /// - sends the initial client update (`send_initial_client_update`),
    ///   releasing `connection_guard` only after that succeeds;
    /// - loops dispatching incoming messages to registered handlers until the
    ///   I/O future finishes or the incoming stream ends, then runs
    ///   `connection_lost`.
    ///
    /// If the teardown channel changes while the loop is running, the future
    /// returns without calling `connection_lost`.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields declared Empty here are filled in later via `record`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new work if the server is already shutting down.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }

            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only built when an admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The connection slot reserved before the WebSocket upgrade is held
            // only until the initial update has been sent.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps how many handlers of this connection can be in flight at once:
            // a permit is taken before reading each message and released when its
            // handler finishes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // The semaphore is never closed in this block, so `unwrap` only
                // guards against that impossible case.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown first, then I/O completion, then draining
                // foreground handlers, and only then reading a new message.
                futures::select_biased! {
                    // Teardown exits without running `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit moves into the handler future and is
                                // released once that handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached; foreground ones are
                                // polled by this loop via `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Any foreground handlers still pending are dropped (cancelled) here.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 894
 895    async fn send_initial_client_update(
 896        &self,
 897        connection_id: ConnectionId,
 898        zed_version: ZedVersion,
 899        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 900        session: &Session,
 901    ) -> Result<()> {
 902        self.peer.send(
 903            connection_id,
 904            proto::Hello {
 905                peer_id: Some(connection_id.into()),
 906            },
 907        )?;
 908        tracing::info!("sent hello message");
 909        if let Some(send_connection_id) = send_connection_id.take() {
 910            let _ = send_connection_id.send(connection_id);
 911        }
 912
 913        match &session.principal {
 914            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 915                if !user.connected_once {
 916                    self.peer.send(connection_id, proto::ShowContacts {})?;
 917                    self.app_state
 918                        .db
 919                        .set_user_connected_once(user.id, true)
 920                        .await?;
 921                }
 922
 923                update_user_plan(session).await?;
 924
 925                let contacts = self.app_state.db.get_contacts(user.id).await?;
 926
 927                {
 928                    let mut pool = self.connection_pool.lock();
 929                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 930                    self.peer.send(
 931                        connection_id,
 932                        build_initial_contacts_update(contacts, &pool),
 933                    )?;
 934                }
 935
 936                if should_auto_subscribe_to_channels(zed_version) {
 937                    subscribe_user_to_channels(user.id, session).await?;
 938                }
 939
 940                if let Some(incoming_call) =
 941                    self.app_state.db.incoming_call_for_user(user.id).await?
 942                {
 943                    self.peer.send(connection_id, incoming_call)?;
 944                }
 945
 946                update_user_contacts(user.id, session).await?;
 947            }
 948        }
 949
 950        Ok(())
 951    }
 952
 953    pub async fn invite_code_redeemed(
 954        self: &Arc<Self>,
 955        inviter_id: UserId,
 956        invitee_id: UserId,
 957    ) -> Result<()> {
 958        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 959            if let Some(code) = &user.invite_code {
 960                let pool = self.connection_pool.lock();
 961                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 962                for connection_id in pool.user_connection_ids(inviter_id) {
 963                    self.peer.send(
 964                        connection_id,
 965                        proto::UpdateContacts {
 966                            contacts: vec![invitee_contact.clone()],
 967                            ..Default::default()
 968                        },
 969                    )?;
 970                    self.peer.send(
 971                        connection_id,
 972                        proto::UpdateInviteInfo {
 973                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 974                            count: user.invite_count as u32,
 975                        },
 976                    )?;
 977                }
 978            }
 979        }
 980        Ok(())
 981    }
 982
 983    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 984        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 985            if let Some(invite_code) = &user.invite_code {
 986                let pool = self.connection_pool.lock();
 987                for connection_id in pool.user_connection_ids(user_id) {
 988                    self.peer.send(
 989                        connection_id,
 990                        proto::UpdateInviteInfo {
 991                            url: format!(
 992                                "{}{}",
 993                                self.app_state.config.invite_link_prefix, invite_code
 994                            ),
 995                            count: user.invite_count as u32,
 996                        },
 997                    )?;
 998                }
 999            }
1000        }
1001        Ok(())
1002    }
1003
1004    pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1005        let user = self
1006            .app_state
1007            .db
1008            .get_user_by_id(user_id)
1009            .await?
1010            .context("user not found")?;
1011
1012        let update_user_plan = make_update_user_plan_message(
1013            &user,
1014            user.admin,
1015            &self.app_state.db,
1016            self.app_state.llm_db.clone(),
1017        )
1018        .await?;
1019
1020        let pool = self.connection_pool.lock();
1021        for connection_id in pool.user_connection_ids(user_id) {
1022            self.peer
1023                .send(connection_id, update_user_plan.clone())
1024                .trace_err();
1025        }
1026
1027        Ok(())
1028    }
1029
1030    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
1031        let pool = self.connection_pool.lock();
1032        for connection_id in pool.user_connection_ids(user_id) {
1033            self.peer
1034                .send(connection_id, proto::RefreshLlmToken {})
1035                .trace_err();
1036        }
1037    }
1038
1039    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
1040        ServerSnapshot {
1041            connection_pool: ConnectionPoolGuard {
1042                guard: self.connection_pool.lock(),
1043                _not_send: PhantomData,
1044            },
1045            peer: &self.peer,
1046        }
1047    }
1048}
1049
1050impl Deref for ConnectionPoolGuard<'_> {
1051    type Target = ConnectionPool;
1052
1053    fn deref(&self) -> &Self::Target {
1054        &self.guard
1055    }
1056}
1057
1058impl DerefMut for ConnectionPoolGuard<'_> {
1059    fn deref_mut(&mut self) -> &mut Self::Target {
1060        &mut self.guard
1061    }
1062}
1063
/// In test builds, verifies the pool's invariants every time a guard is
/// released, so violations surface as soon as the mutating code is done.
/// In non-test builds the drop is a no-op beyond releasing the lock.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1070
1071fn broadcast<F>(
1072    sender_id: Option<ConnectionId>,
1073    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1074    mut f: F,
1075) where
1076    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1077{
1078    for receiver_id in receiver_ids {
1079        if Some(receiver_id) != sender_id {
1080            if let Err(error) = f(receiver_id) {
1081                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1082            }
1083        }
1084    }
1085}
1086
1087pub struct ProtocolVersion(u32);
1088
1089impl Header for ProtocolVersion {
1090    fn name() -> &'static HeaderName {
1091        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1092        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1093    }
1094
1095    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1096    where
1097        Self: Sized,
1098        I: Iterator<Item = &'i axum::http::HeaderValue>,
1099    {
1100        let version = values
1101            .next()
1102            .ok_or_else(axum::headers::Error::invalid)?
1103            .to_str()
1104            .map_err(|_| axum::headers::Error::invalid())?
1105            .parse()
1106            .map_err(|_| axum::headers::Error::invalid())?;
1107        Ok(Self(version))
1108    }
1109
1110    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1111        values.extend([self.0.to_string().parse().unwrap()]);
1112    }
1113}
1114
1115pub struct AppVersionHeader(SemanticVersion);
1116impl Header for AppVersionHeader {
1117    fn name() -> &'static HeaderName {
1118        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1119        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1120    }
1121
1122    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1123    where
1124        Self: Sized,
1125        I: Iterator<Item = &'i axum::http::HeaderValue>,
1126    {
1127        let version = values
1128            .next()
1129            .ok_or_else(axum::headers::Error::invalid)?
1130            .to_str()
1131            .map_err(|_| axum::headers::Error::invalid())?
1132            .parse()
1133            .map_err(|_| axum::headers::Error::invalid())?;
1134        Ok(Self(version))
1135    }
1136
1137    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1138        values.extend([self.0.to_string().parse().unwrap()]);
1139    }
1140}
1141
/// Builds the collab RPC router.
///
/// Layer order matters here: in axum, `Router::layer` wraps only the routes
/// added before it.
/// - `/rpc` is wrapped by the `ServiceBuilder` stack. The `AppState`
///   extension is outermost, so it is present when `auth::validate_header`
///   runs.
/// - `/metrics` is added after that stack and so bypasses header validation.
/// - The final `Extension(server)` layer applies to both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1153
1154pub async fn handle_websocket_request(
1155    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1156    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1157    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1158    Extension(server): Extension<Arc<Server>>,
1159    Extension(principal): Extension<Principal>,
1160    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1161    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1162    ws: WebSocketUpgrade,
1163) -> axum::response::Response {
1164    if protocol_version != rpc::PROTOCOL_VERSION {
1165        return (
1166            StatusCode::UPGRADE_REQUIRED,
1167            "client must be upgraded".to_string(),
1168        )
1169            .into_response();
1170    }
1171
1172    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1173        return (
1174            StatusCode::UPGRADE_REQUIRED,
1175            "no version header found".to_string(),
1176        )
1177            .into_response();
1178    };
1179
1180    if !version.can_collaborate() {
1181        return (
1182            StatusCode::UPGRADE_REQUIRED,
1183            "client must be upgraded".to_string(),
1184        )
1185            .into_response();
1186    }
1187
1188    let socket_address = socket_address.to_string();
1189
1190    // Acquire connection guard before WebSocket upgrade
1191    let connection_guard = match ConnectionGuard::try_acquire() {
1192        Ok(guard) => guard,
1193        Err(()) => {
1194            return (
1195                StatusCode::SERVICE_UNAVAILABLE,
1196                "Too many concurrent connections",
1197            )
1198                .into_response();
1199        }
1200    };
1201
1202    ws.on_upgrade(move |socket| {
1203        let socket = socket
1204            .map_ok(to_tungstenite_message)
1205            .err_into()
1206            .with(|message| async move { to_axum_message(message) });
1207        let connection = Connection::new(Box::pin(socket));
1208        async move {
1209            server
1210                .handle_connection(
1211                    connection,
1212                    socket_address,
1213                    principal,
1214                    version,
1215                    country_code_header.map(|header| header.to_string()),
1216                    system_id_header.map(|header| header.to_string()),
1217                    None,
1218                    Executor::Production,
1219                    Some(connection_guard),
1220                )
1221                .await;
1222        }
1223    })
1224}
1225
1226pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1227    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1228    let connections_metric = CONNECTIONS_METRIC
1229        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1230
1231    let connections = server
1232        .connection_pool
1233        .lock()
1234        .connections()
1235        .filter(|connection| !connection.admin)
1236        .count();
1237    connections_metric.set(connections as _);
1238
1239    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1240    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1241        register_int_gauge!(
1242            "shared_projects",
1243            "number of open projects with one or more guests"
1244        )
1245        .unwrap()
1246    });
1247
1248    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1249    shared_projects_metric.set(shared_projects as _);
1250
1251    let encoder = prometheus::TextEncoder::new();
1252    let metric_families = prometheus::gather();
1253    let encoded_metrics = encoder
1254        .encode_to_string(&metric_families)
1255        .map_err(|err| anyhow!("{err}"))?;
1256    Ok(encoded_metrics)
1257}
1258
1259#[instrument(err, skip(executor))]
1260async fn connection_lost(
1261    session: Session,
1262    mut teardown: watch::Receiver<bool>,
1263    executor: Executor,
1264) -> Result<()> {
1265    session.peer.disconnect(session.connection_id);
1266    session
1267        .connection_pool()
1268        .await
1269        .remove_connection(session.connection_id)?;
1270
1271    session
1272        .db()
1273        .await
1274        .connection_lost(session.connection_id)
1275        .await
1276        .trace_err();
1277
1278    futures::select_biased! {
1279        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1280
1281            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1282            leave_room_for_session(&session, session.connection_id).await.trace_err();
1283            leave_channel_buffers_for_session(&session)
1284                .await
1285                .trace_err();
1286
1287            if !session
1288                .connection_pool()
1289                .await
1290                .is_user_online(session.user_id())
1291            {
1292                let db = session.db().await;
1293                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1294                    room_updated(&room, &session.peer);
1295                }
1296            }
1297
1298            update_user_contacts(session.user_id(), &session).await?;
1299        },
1300        _ = teardown.changed().fuse() => {}
1301    }
1302
1303    Ok(())
1304}
1305
1306/// Acknowledges a ping from a client, used to keep the connection alive.
1307async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1308    response.send(proto::Ack {})?;
1309    Ok(())
1310}
1311
1312/// Creates a new room for calling (outside of channels)
1313async fn create_room(
1314    _request: proto::CreateRoom,
1315    response: Response<proto::CreateRoom>,
1316    session: Session,
1317) -> Result<()> {
1318    let livekit_room = nanoid::nanoid!(30);
1319
1320    let live_kit_connection_info = util::maybe!(async {
1321        let live_kit = session.app_state.livekit_client.as_ref();
1322        let live_kit = live_kit?;
1323        let user_id = session.user_id().to_string();
1324
1325        let token = live_kit
1326            .room_token(&livekit_room, &user_id.to_string())
1327            .trace_err()?;
1328
1329        Some(proto::LiveKitConnectionInfo {
1330            server_url: live_kit.url().into(),
1331            token,
1332            can_publish: true,
1333        })
1334    })
1335    .await;
1336
1337    let room = session
1338        .db()
1339        .await
1340        .create_room(session.user_id(), session.connection_id, &livekit_room)
1341        .await?;
1342
1343    response.send(proto::CreateRoomResponse {
1344        room: Some(room.clone()),
1345        live_kit_connection_info,
1346    })?;
1347
1348    update_user_contacts(session.user_id(), &session).await?;
1349    Ok(())
1350}
1351
1352/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1353async fn join_room(
1354    request: proto::JoinRoom,
1355    response: Response<proto::JoinRoom>,
1356    session: Session,
1357) -> Result<()> {
1358    let room_id = RoomId::from_proto(request.id);
1359
1360    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1361
1362    if let Some(channel_id) = channel_id {
1363        return join_channel_internal(channel_id, Box::new(response), session).await;
1364    }
1365
1366    let joined_room = {
1367        let room = session
1368            .db()
1369            .await
1370            .join_room(room_id, session.user_id(), session.connection_id)
1371            .await?;
1372        room_updated(&room.room, &session.peer);
1373        room.into_inner()
1374    };
1375
1376    for connection_id in session
1377        .connection_pool()
1378        .await
1379        .user_connection_ids(session.user_id())
1380    {
1381        session
1382            .peer
1383            .send(
1384                connection_id,
1385                proto::CallCanceled {
1386                    room_id: room_id.to_proto(),
1387                },
1388            )
1389            .trace_err();
1390    }
1391
1392    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1393    {
1394        live_kit
1395            .room_token(
1396                &joined_room.room.livekit_room,
1397                &session.user_id().to_string(),
1398            )
1399            .trace_err()
1400            .map(|token| proto::LiveKitConnectionInfo {
1401                server_url: live_kit.url().into(),
1402                token,
1403                can_publish: true,
1404            })
1405    } else {
1406        None
1407    };
1408
1409    response.send(proto::JoinRoomResponse {
1410        room: Some(joined_room.room),
1411        channel_id: None,
1412        live_kit_connection_info,
1413    })?;
1414
1415    update_user_contacts(session.user_id(), &session).await?;
1416    Ok(())
1417}
1418
/// Reconnects the session's user to a room after a connection error.
///
/// Steps:
/// 1. Rejoin the room in the database.
/// 2. Reply with the room and the reshared/rejoined projects.
/// 3. Broadcast the room update.
/// 4. For reshared projects, tell every collaborator about the caller's new
///    peer id and forward the project's worktrees to them (excluding the
///    caller).
/// 5. Notify collaborators of rejoined projects.
/// 6. Once the rejoin result has been released, broadcast a channel update if
///    the room belongs to a channel, and refresh the user's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    // This scope bounds the lifetime of `rejoined_room` (unwrapped below with
    // `into_inner()` — presumably a transaction/room guard; confirm) so it is
    // released before the channel update and contact refresh that follow.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Point every collaborator at the caller's new connection id.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktrees to collaborators, skipping the caller.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1510
1511fn notify_rejoined_projects(
1512    rejoined_projects: &mut Vec<RejoinedProject>,
1513    session: &Session,
1514) -> Result<()> {
1515    for project in rejoined_projects.iter() {
1516        for collaborator in &project.collaborators {
1517            session
1518                .peer
1519                .send(
1520                    collaborator.connection_id,
1521                    proto::UpdateProjectCollaborator {
1522                        project_id: project.id.to_proto(),
1523                        old_peer_id: Some(project.old_connection_id.into()),
1524                        new_peer_id: Some(session.connection_id.into()),
1525                    },
1526                )
1527                .trace_err();
1528        }
1529    }
1530
1531    for project in rejoined_projects {
1532        for worktree in mem::take(&mut project.worktrees) {
1533            // Stream this worktree's entries.
1534            let message = proto::UpdateWorktree {
1535                project_id: project.id.to_proto(),
1536                worktree_id: worktree.id,
1537                abs_path: worktree.abs_path.clone(),
1538                root_name: worktree.root_name,
1539                updated_entries: worktree.updated_entries,
1540                removed_entries: worktree.removed_entries,
1541                scan_id: worktree.scan_id,
1542                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1543                updated_repositories: worktree.updated_repositories,
1544                removed_repositories: worktree.removed_repositories,
1545            };
1546            for update in proto::split_worktree_update(message) {
1547                session.peer.send(session.connection_id, update)?;
1548            }
1549
1550            // Stream this worktree's diagnostics.
1551            for summary in worktree.diagnostic_summaries {
1552                session.peer.send(
1553                    session.connection_id,
1554                    proto::UpdateDiagnosticSummary {
1555                        project_id: project.id.to_proto(),
1556                        worktree_id: worktree.id,
1557                        summary: Some(summary),
1558                    },
1559                )?;
1560            }
1561
1562            for settings_file in worktree.settings_files {
1563                session.peer.send(
1564                    session.connection_id,
1565                    proto::UpdateWorktreeSettings {
1566                        project_id: project.id.to_proto(),
1567                        worktree_id: worktree.id,
1568                        path: settings_file.path,
1569                        content: Some(settings_file.content),
1570                        kind: Some(settings_file.kind.to_proto().into()),
1571                    },
1572                )?;
1573            }
1574        }
1575
1576        for repository in mem::take(&mut project.updated_repositories) {
1577            for update in split_repository_update(repository) {
1578                session.peer.send(session.connection_id, update)?;
1579            }
1580        }
1581
1582        for id in mem::take(&mut project.removed_repositories) {
1583            session.peer.send(
1584                session.connection_id,
1585                proto::RemoveRepository {
1586                    project_id: project.id.to_proto(),
1587                    id,
1588                },
1589            )?;
1590        }
1591    }
1592
1593    Ok(())
1594}
1595
1596/// leave room disconnects from the room.
1597async fn leave_room(
1598    _: proto::LeaveRoom,
1599    response: Response<proto::LeaveRoom>,
1600    session: Session,
1601) -> Result<()> {
1602    leave_room_for_session(&session, session.connection_id).await?;
1603    response.send(proto::Ack {})?;
1604    Ok(())
1605}
1606
1607/// Updates the permissions of someone else in the room.
1608async fn set_room_participant_role(
1609    request: proto::SetRoomParticipantRole,
1610    response: Response<proto::SetRoomParticipantRole>,
1611    session: Session,
1612) -> Result<()> {
1613    let user_id = UserId::from_proto(request.user_id);
1614    let role = ChannelRole::from(request.role());
1615
1616    let (livekit_room, can_publish) = {
1617        let room = session
1618            .db()
1619            .await
1620            .set_room_participant_role(
1621                session.user_id(),
1622                RoomId::from_proto(request.room_id),
1623                user_id,
1624                role,
1625            )
1626            .await?;
1627
1628        let livekit_room = room.livekit_room.clone();
1629        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1630        room_updated(&room, &session.peer);
1631        (livekit_room, can_publish)
1632    };
1633
1634    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1635        live_kit
1636            .update_participant(
1637                livekit_room.clone(),
1638                request.user_id.to_string(),
1639                livekit_api::proto::ParticipantPermission {
1640                    can_subscribe: true,
1641                    can_publish,
1642                    can_publish_data: can_publish,
1643                    hidden: false,
1644                    recorder: false,
1645                },
1646            )
1647            .await
1648            .trace_err();
1649    }
1650
1651    response.send(proto::Ack {})?;
1652    Ok(())
1653}
1654
1655/// Call someone else into the current room
1656async fn call(
1657    request: proto::Call,
1658    response: Response<proto::Call>,
1659    session: Session,
1660) -> Result<()> {
1661    let room_id = RoomId::from_proto(request.room_id);
1662    let calling_user_id = session.user_id();
1663    let calling_connection_id = session.connection_id;
1664    let called_user_id = UserId::from_proto(request.called_user_id);
1665    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1666    if !session
1667        .db()
1668        .await
1669        .has_contact(calling_user_id, called_user_id)
1670        .await?
1671    {
1672        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1673    }
1674
1675    let incoming_call = {
1676        let (room, incoming_call) = &mut *session
1677            .db()
1678            .await
1679            .call(
1680                room_id,
1681                calling_user_id,
1682                calling_connection_id,
1683                called_user_id,
1684                initial_project_id,
1685            )
1686            .await?;
1687        room_updated(room, &session.peer);
1688        mem::take(incoming_call)
1689    };
1690    update_user_contacts(called_user_id, &session).await?;
1691
1692    let mut calls = session
1693        .connection_pool()
1694        .await
1695        .user_connection_ids(called_user_id)
1696        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1697        .collect::<FuturesUnordered<_>>();
1698
1699    while let Some(call_response) = calls.next().await {
1700        match call_response.as_ref() {
1701            Ok(_) => {
1702                response.send(proto::Ack {})?;
1703                return Ok(());
1704            }
1705            Err(_) => {
1706                call_response.trace_err();
1707            }
1708        }
1709    }
1710
1711    {
1712        let room = session
1713            .db()
1714            .await
1715            .call_failed(room_id, called_user_id)
1716            .await?;
1717        room_updated(&room, &session.peer);
1718    }
1719    update_user_contacts(called_user_id, &session).await?;
1720
1721    Err(anyhow!("failed to ring user"))?
1722}
1723
1724/// Cancel an outgoing call.
1725async fn cancel_call(
1726    request: proto::CancelCall,
1727    response: Response<proto::CancelCall>,
1728    session: Session,
1729) -> Result<()> {
1730    let called_user_id = UserId::from_proto(request.called_user_id);
1731    let room_id = RoomId::from_proto(request.room_id);
1732    {
1733        let room = session
1734            .db()
1735            .await
1736            .cancel_call(room_id, session.connection_id, called_user_id)
1737            .await?;
1738        room_updated(&room, &session.peer);
1739    }
1740
1741    for connection_id in session
1742        .connection_pool()
1743        .await
1744        .user_connection_ids(called_user_id)
1745    {
1746        session
1747            .peer
1748            .send(
1749                connection_id,
1750                proto::CallCanceled {
1751                    room_id: room_id.to_proto(),
1752                },
1753            )
1754            .trace_err();
1755    }
1756    response.send(proto::Ack {})?;
1757
1758    update_user_contacts(called_user_id, &session).await?;
1759    Ok(())
1760}
1761
1762/// Decline an incoming call.
1763async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1764    let room_id = RoomId::from_proto(message.room_id);
1765    {
1766        let room = session
1767            .db()
1768            .await
1769            .decline_call(Some(room_id), session.user_id())
1770            .await?
1771            .context("declining call")?;
1772        room_updated(&room, &session.peer);
1773    }
1774
1775    for connection_id in session
1776        .connection_pool()
1777        .await
1778        .user_connection_ids(session.user_id())
1779    {
1780        session
1781            .peer
1782            .send(
1783                connection_id,
1784                proto::CallCanceled {
1785                    room_id: room_id.to_proto(),
1786                },
1787            )
1788            .trace_err();
1789    }
1790    update_user_contacts(session.user_id(), &session).await?;
1791    Ok(())
1792}
1793
1794/// Updates other participants in the room with your current location.
1795async fn update_participant_location(
1796    request: proto::UpdateParticipantLocation,
1797    response: Response<proto::UpdateParticipantLocation>,
1798    session: Session,
1799) -> Result<()> {
1800    let room_id = RoomId::from_proto(request.room_id);
1801    let location = request.location.context("invalid location")?;
1802
1803    let db = session.db().await;
1804    let room = db
1805        .update_room_participant_location(room_id, session.connection_id, location)
1806        .await?;
1807
1808    room_updated(&room, &session.peer);
1809    response.send(proto::Ack {})?;
1810    Ok(())
1811}
1812
1813/// Share a project into the room.
1814async fn share_project(
1815    request: proto::ShareProject,
1816    response: Response<proto::ShareProject>,
1817    session: Session,
1818) -> Result<()> {
1819    let (project_id, room) = &*session
1820        .db()
1821        .await
1822        .share_project(
1823            RoomId::from_proto(request.room_id),
1824            session.connection_id,
1825            &request.worktrees,
1826            request.is_ssh_project,
1827        )
1828        .await?;
1829    response.send(proto::ShareProjectResponse {
1830        project_id: project_id.to_proto(),
1831    })?;
1832    room_updated(room, &session.peer);
1833
1834    Ok(())
1835}
1836
1837/// Unshare a project from the room.
1838async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1839    let project_id = ProjectId::from_proto(message.project_id);
1840    unshare_project_internal(project_id, session.connection_id, &session).await
1841}
1842
1843async fn unshare_project_internal(
1844    project_id: ProjectId,
1845    connection_id: ConnectionId,
1846    session: &Session,
1847) -> Result<()> {
1848    let delete = {
1849        let room_guard = session
1850            .db()
1851            .await
1852            .unshare_project(project_id, connection_id)
1853            .await?;
1854
1855        let (delete, room, guest_connection_ids) = &*room_guard;
1856
1857        let message = proto::UnshareProject {
1858            project_id: project_id.to_proto(),
1859        };
1860
1861        broadcast(
1862            Some(connection_id),
1863            guest_connection_ids.iter().copied(),
1864            |conn_id| session.peer.send(conn_id, message.clone()),
1865        );
1866        if let Some(room) = room {
1867            room_updated(room, &session.peer);
1868        }
1869
1870        *delete
1871    };
1872
1873    if delete {
1874        let db = session.db().await;
1875        db.delete_project(project_id).await?;
1876    }
1877
1878    Ok(())
1879}
1880
1881/// Join someone elses shared project.
1882async fn join_project(
1883    request: proto::JoinProject,
1884    response: Response<proto::JoinProject>,
1885    session: Session,
1886) -> Result<()> {
1887    let project_id = ProjectId::from_proto(request.project_id);
1888
1889    tracing::info!(%project_id, "join project");
1890
1891    let db = session.db().await;
1892    let (project, replica_id) = &mut *db
1893        .join_project(project_id, session.connection_id, session.user_id())
1894        .await?;
1895    drop(db);
1896    tracing::info!(%project_id, "join remote project");
1897    join_project_internal(response, session, project, replica_id)
1898}
1899
/// Something that can answer a join-project request with a
/// `proto::JoinProjectResponse`.
///
/// `join_project_internal` accepts any implementor rather than a concrete
/// `Response<proto::JoinProject>`.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the join request, consuming the responder.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Explicit path so the call is delegated to `Response`'s own `send`
        // rather than re-entering this trait method of the same name.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1908
1909fn join_project_internal(
1910    response: impl JoinProjectInternalResponse,
1911    session: Session,
1912    project: &mut Project,
1913    replica_id: &ReplicaId,
1914) -> Result<()> {
1915    let collaborators = project
1916        .collaborators
1917        .iter()
1918        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1919        .map(|collaborator| collaborator.to_proto())
1920        .collect::<Vec<_>>();
1921    let project_id = project.id;
1922    let guest_user_id = session.user_id();
1923
1924    let worktrees = project
1925        .worktrees
1926        .iter()
1927        .map(|(id, worktree)| proto::WorktreeMetadata {
1928            id: *id,
1929            root_name: worktree.root_name.clone(),
1930            visible: worktree.visible,
1931            abs_path: worktree.abs_path.clone(),
1932        })
1933        .collect::<Vec<_>>();
1934
1935    let add_project_collaborator = proto::AddProjectCollaborator {
1936        project_id: project_id.to_proto(),
1937        collaborator: Some(proto::Collaborator {
1938            peer_id: Some(session.connection_id.into()),
1939            replica_id: replica_id.0 as u32,
1940            user_id: guest_user_id.to_proto(),
1941            is_host: false,
1942        }),
1943    };
1944
1945    for collaborator in &collaborators {
1946        session
1947            .peer
1948            .send(
1949                collaborator.peer_id.unwrap().into(),
1950                add_project_collaborator.clone(),
1951            )
1952            .trace_err();
1953    }
1954
1955    // First, we send the metadata associated with each worktree.
1956    response.send(proto::JoinProjectResponse {
1957        project_id: project.id.0 as u64,
1958        worktrees: worktrees.clone(),
1959        replica_id: replica_id.0 as u32,
1960        collaborators: collaborators.clone(),
1961        language_servers: project.language_servers.clone(),
1962        role: project.role.into(),
1963    })?;
1964
1965    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1966        // Stream this worktree's entries.
1967        let message = proto::UpdateWorktree {
1968            project_id: project_id.to_proto(),
1969            worktree_id,
1970            abs_path: worktree.abs_path.clone(),
1971            root_name: worktree.root_name,
1972            updated_entries: worktree.entries,
1973            removed_entries: Default::default(),
1974            scan_id: worktree.scan_id,
1975            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1976            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1977            removed_repositories: Default::default(),
1978        };
1979        for update in proto::split_worktree_update(message) {
1980            session.peer.send(session.connection_id, update.clone())?;
1981        }
1982
1983        // Stream this worktree's diagnostics.
1984        for summary in worktree.diagnostic_summaries {
1985            session.peer.send(
1986                session.connection_id,
1987                proto::UpdateDiagnosticSummary {
1988                    project_id: project_id.to_proto(),
1989                    worktree_id: worktree.id,
1990                    summary: Some(summary),
1991                },
1992            )?;
1993        }
1994
1995        for settings_file in worktree.settings_files {
1996            session.peer.send(
1997                session.connection_id,
1998                proto::UpdateWorktreeSettings {
1999                    project_id: project_id.to_proto(),
2000                    worktree_id: worktree.id,
2001                    path: settings_file.path,
2002                    content: Some(settings_file.content),
2003                    kind: Some(settings_file.kind.to_proto() as i32),
2004                },
2005            )?;
2006        }
2007    }
2008
2009    for repository in mem::take(&mut project.repositories) {
2010        for update in split_repository_update(repository) {
2011            session.peer.send(session.connection_id, update)?;
2012        }
2013    }
2014
2015    for language_server in &project.language_servers {
2016        session.peer.send(
2017            session.connection_id,
2018            proto::UpdateLanguageServer {
2019                project_id: project_id.to_proto(),
2020                language_server_id: language_server.id,
2021                variant: Some(
2022                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2023                        proto::LspDiskBasedDiagnosticsUpdated {},
2024                    ),
2025                ),
2026            },
2027        )?;
2028    }
2029
2030    Ok(())
2031}
2032
2033/// Leave someone elses shared project.
2034async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
2035    let sender_id = session.connection_id;
2036    let project_id = ProjectId::from_proto(request.project_id);
2037    let db = session.db().await;
2038
2039    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2040    tracing::info!(
2041        %project_id,
2042        "leave project"
2043    );
2044
2045    project_left(project, &session);
2046    if let Some(room) = room {
2047        room_updated(room, &session.peer);
2048    }
2049
2050    Ok(())
2051}
2052
2053/// Updates other participants with changes to the project
2054async fn update_project(
2055    request: proto::UpdateProject,
2056    response: Response<proto::UpdateProject>,
2057    session: Session,
2058) -> Result<()> {
2059    let project_id = ProjectId::from_proto(request.project_id);
2060    let (room, guest_connection_ids) = &*session
2061        .db()
2062        .await
2063        .update_project(project_id, session.connection_id, &request.worktrees)
2064        .await?;
2065    broadcast(
2066        Some(session.connection_id),
2067        guest_connection_ids.iter().copied(),
2068        |connection_id| {
2069            session
2070                .peer
2071                .forward_send(session.connection_id, connection_id, request.clone())
2072        },
2073    );
2074    if let Some(room) = room {
2075        room_updated(room, &session.peer);
2076    }
2077    response.send(proto::Ack {})?;
2078
2079    Ok(())
2080}
2081
2082/// Updates other participants with changes to the worktree
2083async fn update_worktree(
2084    request: proto::UpdateWorktree,
2085    response: Response<proto::UpdateWorktree>,
2086    session: Session,
2087) -> Result<()> {
2088    let guest_connection_ids = session
2089        .db()
2090        .await
2091        .update_worktree(&request, session.connection_id)
2092        .await?;
2093
2094    broadcast(
2095        Some(session.connection_id),
2096        guest_connection_ids.iter().copied(),
2097        |connection_id| {
2098            session
2099                .peer
2100                .forward_send(session.connection_id, connection_id, request.clone())
2101        },
2102    );
2103    response.send(proto::Ack {})?;
2104    Ok(())
2105}
2106
2107async fn update_repository(
2108    request: proto::UpdateRepository,
2109    response: Response<proto::UpdateRepository>,
2110    session: Session,
2111) -> Result<()> {
2112    let guest_connection_ids = session
2113        .db()
2114        .await
2115        .update_repository(&request, session.connection_id)
2116        .await?;
2117
2118    broadcast(
2119        Some(session.connection_id),
2120        guest_connection_ids.iter().copied(),
2121        |connection_id| {
2122            session
2123                .peer
2124                .forward_send(session.connection_id, connection_id, request.clone())
2125        },
2126    );
2127    response.send(proto::Ack {})?;
2128    Ok(())
2129}
2130
2131async fn remove_repository(
2132    request: proto::RemoveRepository,
2133    response: Response<proto::RemoveRepository>,
2134    session: Session,
2135) -> Result<()> {
2136    let guest_connection_ids = session
2137        .db()
2138        .await
2139        .remove_repository(&request, session.connection_id)
2140        .await?;
2141
2142    broadcast(
2143        Some(session.connection_id),
2144        guest_connection_ids.iter().copied(),
2145        |connection_id| {
2146            session
2147                .peer
2148                .forward_send(session.connection_id, connection_id, request.clone())
2149        },
2150    );
2151    response.send(proto::Ack {})?;
2152    Ok(())
2153}
2154
2155/// Updates other participants with changes to the diagnostics
2156async fn update_diagnostic_summary(
2157    message: proto::UpdateDiagnosticSummary,
2158    session: Session,
2159) -> Result<()> {
2160    let guest_connection_ids = session
2161        .db()
2162        .await
2163        .update_diagnostic_summary(&message, session.connection_id)
2164        .await?;
2165
2166    broadcast(
2167        Some(session.connection_id),
2168        guest_connection_ids.iter().copied(),
2169        |connection_id| {
2170            session
2171                .peer
2172                .forward_send(session.connection_id, connection_id, message.clone())
2173        },
2174    );
2175
2176    Ok(())
2177}
2178
2179/// Updates other participants with changes to the worktree settings
2180async fn update_worktree_settings(
2181    message: proto::UpdateWorktreeSettings,
2182    session: Session,
2183) -> Result<()> {
2184    let guest_connection_ids = session
2185        .db()
2186        .await
2187        .update_worktree_settings(&message, session.connection_id)
2188        .await?;
2189
2190    broadcast(
2191        Some(session.connection_id),
2192        guest_connection_ids.iter().copied(),
2193        |connection_id| {
2194            session
2195                .peer
2196                .forward_send(session.connection_id, connection_id, message.clone())
2197        },
2198    );
2199
2200    Ok(())
2201}
2202
2203/// Notify other participants that a language server has started.
2204async fn start_language_server(
2205    request: proto::StartLanguageServer,
2206    session: Session,
2207) -> Result<()> {
2208    let guest_connection_ids = session
2209        .db()
2210        .await
2211        .start_language_server(&request, session.connection_id)
2212        .await?;
2213
2214    broadcast(
2215        Some(session.connection_id),
2216        guest_connection_ids.iter().copied(),
2217        |connection_id| {
2218            session
2219                .peer
2220                .forward_send(session.connection_id, connection_id, request.clone())
2221        },
2222    );
2223    Ok(())
2224}
2225
2226/// Notify other participants that a language server has changed.
2227async fn update_language_server(
2228    request: proto::UpdateLanguageServer,
2229    session: Session,
2230) -> Result<()> {
2231    let project_id = ProjectId::from_proto(request.project_id);
2232    let project_connection_ids = session
2233        .db()
2234        .await
2235        .project_connection_ids(project_id, session.connection_id, true)
2236        .await?;
2237    broadcast(
2238        Some(session.connection_id),
2239        project_connection_ids.iter().copied(),
2240        |connection_id| {
2241            session
2242                .peer
2243                .forward_send(session.connection_id, connection_id, request.clone())
2244        },
2245    );
2246    Ok(())
2247}
2248
2249/// forward a project request to the host. These requests should be read only
2250/// as guests are allowed to send them.
2251async fn forward_read_only_project_request<T>(
2252    request: T,
2253    response: Response<T>,
2254    session: Session,
2255) -> Result<()>
2256where
2257    T: EntityMessage + RequestMessage,
2258{
2259    let project_id = ProjectId::from_proto(request.remote_entity_id());
2260    let host_connection_id = session
2261        .db()
2262        .await
2263        .host_for_read_only_project_request(project_id, session.connection_id)
2264        .await?;
2265    let payload = session
2266        .peer
2267        .forward_request(session.connection_id, host_connection_id, request)
2268        .await?;
2269    response.send(payload)?;
2270    Ok(())
2271}
2272
2273async fn forward_find_search_candidates_request(
2274    request: proto::FindSearchCandidates,
2275    response: Response<proto::FindSearchCandidates>,
2276    session: Session,
2277) -> Result<()> {
2278    let project_id = ProjectId::from_proto(request.remote_entity_id());
2279    let host_connection_id = session
2280        .db()
2281        .await
2282        .host_for_read_only_project_request(project_id, session.connection_id)
2283        .await?;
2284    let payload = session
2285        .peer
2286        .forward_request(session.connection_id, host_connection_id, request)
2287        .await?;
2288    response.send(payload)?;
2289    Ok(())
2290}
2291
2292/// forward a project request to the host. These requests are disallowed
2293/// for guests.
2294async fn forward_mutating_project_request<T>(
2295    request: T,
2296    response: Response<T>,
2297    session: Session,
2298) -> Result<()>
2299where
2300    T: EntityMessage + RequestMessage,
2301{
2302    let project_id = ProjectId::from_proto(request.remote_entity_id());
2303
2304    let host_connection_id = session
2305        .db()
2306        .await
2307        .host_for_mutating_project_request(project_id, session.connection_id)
2308        .await?;
2309    let payload = session
2310        .peer
2311        .forward_request(session.connection_id, host_connection_id, request)
2312        .await?;
2313    response.send(payload)?;
2314    Ok(())
2315}
2316
2317/// Notify other participants that a new buffer has been created
2318async fn create_buffer_for_peer(
2319    request: proto::CreateBufferForPeer,
2320    session: Session,
2321) -> Result<()> {
2322    session
2323        .db()
2324        .await
2325        .check_user_is_project_host(
2326            ProjectId::from_proto(request.project_id),
2327            session.connection_id,
2328        )
2329        .await?;
2330    let peer_id = request.peer_id.context("invalid peer id")?;
2331    session
2332        .peer
2333        .forward_send(session.connection_id, peer_id.into(), request)?;
2334    Ok(())
2335}
2336
2337/// Notify other participants that a buffer has been updated. This is
2338/// allowed for guests as long as the update is limited to selections.
2339async fn update_buffer(
2340    request: proto::UpdateBuffer,
2341    response: Response<proto::UpdateBuffer>,
2342    session: Session,
2343) -> Result<()> {
2344    let project_id = ProjectId::from_proto(request.project_id);
2345    let mut capability = Capability::ReadOnly;
2346
2347    for op in request.operations.iter() {
2348        match op.variant {
2349            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2350            Some(_) => capability = Capability::ReadWrite,
2351        }
2352    }
2353
2354    let host = {
2355        let guard = session
2356            .db()
2357            .await
2358            .connections_for_buffer_update(project_id, session.connection_id, capability)
2359            .await?;
2360
2361        let (host, guests) = &*guard;
2362
2363        broadcast(
2364            Some(session.connection_id),
2365            guests.clone(),
2366            |connection_id| {
2367                session
2368                    .peer
2369                    .forward_send(session.connection_id, connection_id, request.clone())
2370            },
2371        );
2372
2373        *host
2374    };
2375
2376    if host != session.connection_id {
2377        session
2378            .peer
2379            .forward_request(session.connection_id, host, request.clone())
2380            .await?;
2381    }
2382
2383    response.send(proto::Ack {})?;
2384    Ok(())
2385}
2386
2387async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2388    let project_id = ProjectId::from_proto(message.project_id);
2389
2390    let operation = message.operation.as_ref().context("invalid operation")?;
2391    let capability = match operation.variant.as_ref() {
2392        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2393            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2394                match buffer_op.variant {
2395                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2396                        Capability::ReadOnly
2397                    }
2398                    _ => Capability::ReadWrite,
2399                }
2400            } else {
2401                Capability::ReadWrite
2402            }
2403        }
2404        Some(_) => Capability::ReadWrite,
2405        None => Capability::ReadOnly,
2406    };
2407
2408    let guard = session
2409        .db()
2410        .await
2411        .connections_for_buffer_update(project_id, session.connection_id, capability)
2412        .await?;
2413
2414    let (host, guests) = &*guard;
2415
2416    broadcast(
2417        Some(session.connection_id),
2418        guests.iter().chain([host]).copied(),
2419        |connection_id| {
2420            session
2421                .peer
2422                .forward_send(session.connection_id, connection_id, message.clone())
2423        },
2424    );
2425
2426    Ok(())
2427}
2428
2429/// Notify other participants that a project has been updated.
2430async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2431    request: T,
2432    session: Session,
2433) -> Result<()> {
2434    let project_id = ProjectId::from_proto(request.remote_entity_id());
2435    let project_connection_ids = session
2436        .db()
2437        .await
2438        .project_connection_ids(project_id, session.connection_id, false)
2439        .await?;
2440
2441    broadcast(
2442        Some(session.connection_id),
2443        project_connection_ids.iter().copied(),
2444        |connection_id| {
2445            session
2446                .peer
2447                .forward_send(session.connection_id, connection_id, request.clone())
2448        },
2449    );
2450    Ok(())
2451}
2452
2453/// Start following another user in a call.
2454async fn follow(
2455    request: proto::Follow,
2456    response: Response<proto::Follow>,
2457    session: Session,
2458) -> Result<()> {
2459    let room_id = RoomId::from_proto(request.room_id);
2460    let project_id = request.project_id.map(ProjectId::from_proto);
2461    let leader_id = request.leader_id.context("invalid leader id")?.into();
2462    let follower_id = session.connection_id;
2463
2464    session
2465        .db()
2466        .await
2467        .check_room_participants(room_id, leader_id, session.connection_id)
2468        .await?;
2469
2470    let response_payload = session
2471        .peer
2472        .forward_request(session.connection_id, leader_id, request)
2473        .await?;
2474    response.send(response_payload)?;
2475
2476    if let Some(project_id) = project_id {
2477        let room = session
2478            .db()
2479            .await
2480            .follow(room_id, project_id, leader_id, follower_id)
2481            .await?;
2482        room_updated(&room, &session.peer);
2483    }
2484
2485    Ok(())
2486}
2487
2488/// Stop following another user in a call.
2489async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2490    let room_id = RoomId::from_proto(request.room_id);
2491    let project_id = request.project_id.map(ProjectId::from_proto);
2492    let leader_id = request.leader_id.context("invalid leader id")?.into();
2493    let follower_id = session.connection_id;
2494
2495    session
2496        .db()
2497        .await
2498        .check_room_participants(room_id, leader_id, session.connection_id)
2499        .await?;
2500
2501    session
2502        .peer
2503        .forward_send(session.connection_id, leader_id, request)?;
2504
2505    if let Some(project_id) = project_id {
2506        let room = session
2507            .db()
2508            .await
2509            .unfollow(room_id, project_id, leader_id, follower_id)
2510            .await?;
2511        room_updated(&room, &session.peer);
2512    }
2513
2514    Ok(())
2515}
2516
2517/// Notify everyone following you of your current location.
2518async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2519    let room_id = RoomId::from_proto(request.room_id);
2520    let database = session.db.lock().await;
2521
2522    let connection_ids = if let Some(project_id) = request.project_id {
2523        let project_id = ProjectId::from_proto(project_id);
2524        database
2525            .project_connection_ids(project_id, session.connection_id, true)
2526            .await?
2527    } else {
2528        database
2529            .room_connection_ids(room_id, session.connection_id)
2530            .await?
2531    };
2532
2533    // For now, don't send view update messages back to that view's current leader.
2534    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2535        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2536        _ => None,
2537    });
2538
2539    for connection_id in connection_ids.iter().cloned() {
2540        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2541            session
2542                .peer
2543                .forward_send(session.connection_id, connection_id, request.clone())?;
2544        }
2545    }
2546    Ok(())
2547}
2548
2549/// Get public data about users.
2550async fn get_users(
2551    request: proto::GetUsers,
2552    response: Response<proto::GetUsers>,
2553    session: Session,
2554) -> Result<()> {
2555    let user_ids = request
2556        .user_ids
2557        .into_iter()
2558        .map(UserId::from_proto)
2559        .collect();
2560    let users = session
2561        .db()
2562        .await
2563        .get_users_by_ids(user_ids)
2564        .await?
2565        .into_iter()
2566        .map(|user| proto::User {
2567            id: user.id.to_proto(),
2568            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2569            github_login: user.github_login,
2570            email: user.email_address,
2571            name: user.name,
2572        })
2573        .collect();
2574    response.send(proto::UsersResponse { users })?;
2575    Ok(())
2576}
2577
2578/// Search for users (to invite) buy Github login
2579async fn fuzzy_search_users(
2580    request: proto::FuzzySearchUsers,
2581    response: Response<proto::FuzzySearchUsers>,
2582    session: Session,
2583) -> Result<()> {
2584    let query = request.query;
2585    let users = match query.len() {
2586        0 => vec![],
2587        1 | 2 => session
2588            .db()
2589            .await
2590            .get_user_by_github_login(&query)
2591            .await?
2592            .into_iter()
2593            .collect(),
2594        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2595    };
2596    let users = users
2597        .into_iter()
2598        .filter(|user| user.id != session.user_id())
2599        .map(|user| proto::User {
2600            id: user.id.to_proto(),
2601            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2602            github_login: user.github_login,
2603            name: user.name,
2604            email: user.email_address,
2605        })
2606        .collect();
2607    response.send(proto::UsersResponse { users })?;
2608    Ok(())
2609}
2610
2611/// Send a contact request to another user.
2612async fn request_contact(
2613    request: proto::RequestContact,
2614    response: Response<proto::RequestContact>,
2615    session: Session,
2616) -> Result<()> {
2617    let requester_id = session.user_id();
2618    let responder_id = UserId::from_proto(request.responder_id);
2619    if requester_id == responder_id {
2620        return Err(anyhow!("cannot add yourself as a contact"))?;
2621    }
2622
2623    let notifications = session
2624        .db()
2625        .await
2626        .send_contact_request(requester_id, responder_id)
2627        .await?;
2628
2629    // Update outgoing contact requests of requester
2630    let mut update = proto::UpdateContacts::default();
2631    update.outgoing_requests.push(responder_id.to_proto());
2632    for connection_id in session
2633        .connection_pool()
2634        .await
2635        .user_connection_ids(requester_id)
2636    {
2637        session.peer.send(connection_id, update.clone())?;
2638    }
2639
2640    // Update incoming contact requests of responder
2641    let mut update = proto::UpdateContacts::default();
2642    update
2643        .incoming_requests
2644        .push(proto::IncomingContactRequest {
2645            requester_id: requester_id.to_proto(),
2646        });
2647    let connection_pool = session.connection_pool().await;
2648    for connection_id in connection_pool.user_connection_ids(responder_id) {
2649        session.peer.send(connection_id, update.clone())?;
2650    }
2651
2652    send_notifications(&connection_pool, &session.peer, notifications);
2653
2654    response.send(proto::Ack {})?;
2655    Ok(())
2656}
2657
2658/// Accept or decline a contact request
2659async fn respond_to_contact_request(
2660    request: proto::RespondToContactRequest,
2661    response: Response<proto::RespondToContactRequest>,
2662    session: Session,
2663) -> Result<()> {
2664    let responder_id = session.user_id();
2665    let requester_id = UserId::from_proto(request.requester_id);
2666    let db = session.db().await;
2667    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2668        db.dismiss_contact_notification(responder_id, requester_id)
2669            .await?;
2670    } else {
2671        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2672
2673        let notifications = db
2674            .respond_to_contact_request(responder_id, requester_id, accept)
2675            .await?;
2676        let requester_busy = db.is_user_busy(requester_id).await?;
2677        let responder_busy = db.is_user_busy(responder_id).await?;
2678
2679        let pool = session.connection_pool().await;
2680        // Update responder with new contact
2681        let mut update = proto::UpdateContacts::default();
2682        if accept {
2683            update
2684                .contacts
2685                .push(contact_for_user(requester_id, requester_busy, &pool));
2686        }
2687        update
2688            .remove_incoming_requests
2689            .push(requester_id.to_proto());
2690        for connection_id in pool.user_connection_ids(responder_id) {
2691            session.peer.send(connection_id, update.clone())?;
2692        }
2693
2694        // Update requester with new contact
2695        let mut update = proto::UpdateContacts::default();
2696        if accept {
2697            update
2698                .contacts
2699                .push(contact_for_user(responder_id, responder_busy, &pool));
2700        }
2701        update
2702            .remove_outgoing_requests
2703            .push(responder_id.to_proto());
2704
2705        for connection_id in pool.user_connection_ids(requester_id) {
2706            session.peer.send(connection_id, update.clone())?;
2707        }
2708
2709        send_notifications(&pool, &session.peer, notifications);
2710    }
2711
2712    response.send(proto::Ack {})?;
2713    Ok(())
2714}
2715
2716/// Remove a contact.
2717async fn remove_contact(
2718    request: proto::RemoveContact,
2719    response: Response<proto::RemoveContact>,
2720    session: Session,
2721) -> Result<()> {
2722    let requester_id = session.user_id();
2723    let responder_id = UserId::from_proto(request.user_id);
2724    let db = session.db().await;
2725    let (contact_accepted, deleted_notification_id) =
2726        db.remove_contact(requester_id, responder_id).await?;
2727
2728    let pool = session.connection_pool().await;
2729    // Update outgoing contact requests of requester
2730    let mut update = proto::UpdateContacts::default();
2731    if contact_accepted {
2732        update.remove_contacts.push(responder_id.to_proto());
2733    } else {
2734        update
2735            .remove_outgoing_requests
2736            .push(responder_id.to_proto());
2737    }
2738    for connection_id in pool.user_connection_ids(requester_id) {
2739        session.peer.send(connection_id, update.clone())?;
2740    }
2741
2742    // Update incoming contact requests of responder
2743    let mut update = proto::UpdateContacts::default();
2744    if contact_accepted {
2745        update.remove_contacts.push(requester_id.to_proto());
2746    } else {
2747        update
2748            .remove_incoming_requests
2749            .push(requester_id.to_proto());
2750    }
2751    for connection_id in pool.user_connection_ids(responder_id) {
2752        session.peer.send(connection_id, update.clone())?;
2753        if let Some(notification_id) = deleted_notification_id {
2754            session.peer.send(
2755                connection_id,
2756                proto::DeleteNotification {
2757                    notification_id: notification_id.to_proto(),
2758                },
2759            )?;
2760        }
2761    }
2762
2763    response.send(proto::Ack {})?;
2764    Ok(())
2765}
2766
2767fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2768    version.0.minor() < 139
2769}
2770
2771async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2772    if is_staff {
2773        return Ok(proto::Plan::ZedPro);
2774    }
2775
2776    let subscription = db.get_active_billing_subscription(user_id).await?;
2777    let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2778
2779    let plan = if let Some(subscription_kind) = subscription_kind {
2780        match subscription_kind {
2781            SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2782            SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2783            SubscriptionKind::ZedFree => proto::Plan::Free,
2784        }
2785    } else {
2786        proto::Plan::Free
2787    };
2788
2789    Ok(plan)
2790}
2791
2792async fn make_update_user_plan_message(
2793    user: &User,
2794    is_staff: bool,
2795    db: &Arc<Database>,
2796    llm_db: Option<Arc<LlmDatabase>>,
2797) -> Result<proto::UpdateUserPlan> {
2798    let feature_flags = db.get_user_flags(user.id).await?;
2799    let plan = current_plan(db, user.id, is_staff).await?;
2800    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
2801    let billing_preferences = db.get_billing_preferences(user.id).await?;
2802
2803    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
2804        let subscription = db.get_active_billing_subscription(user.id).await?;
2805
2806        let subscription_period =
2807            crate::db::billing_subscription::Model::current_period(subscription, is_staff);
2808
2809        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
2810            llm_db
2811                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
2812                .await?
2813        } else {
2814            None
2815        };
2816
2817        (subscription_period, usage)
2818    } else {
2819        (None, None)
2820    };
2821
2822    let bypass_account_age_check = feature_flags
2823        .iter()
2824        .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
2825    let account_too_young = !matches!(plan, proto::Plan::ZedPro)
2826        && !bypass_account_age_check
2827        && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;
2828
2829    Ok(proto::UpdateUserPlan {
2830        plan: plan.into(),
2831        trial_started_at: billing_customer
2832            .as_ref()
2833            .and_then(|billing_customer| billing_customer.trial_started_at)
2834            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2835        is_usage_based_billing_enabled: if is_staff {
2836            Some(true)
2837        } else {
2838            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
2839        },
2840        subscription_period: subscription_period.map(|(started_at, ended_at)| {
2841            proto::SubscriptionPeriod {
2842                started_at: started_at.timestamp() as u64,
2843                ended_at: ended_at.timestamp() as u64,
2844            }
2845        }),
2846        account_too_young: Some(account_too_young),
2847        has_overdue_invoices: billing_customer
2848            .map(|billing_customer| billing_customer.has_overdue_invoices),
2849        usage: usage.map(|usage| {
2850            let plan = match plan {
2851                proto::Plan::Free => zed_llm_client::Plan::ZedFree,
2852                proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
2853                proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
2854            };
2855
2856            let model_requests_limit = match plan.model_requests_limit() {
2857                zed_llm_client::UsageLimit::Limited(limit) => {
2858                    let limit = if plan == zed_llm_client::Plan::ZedProTrial
2859                        && feature_flags
2860                            .iter()
2861                            .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2862                    {
2863                        1_000
2864                    } else {
2865                        limit
2866                    };
2867
2868                    zed_llm_client::UsageLimit::Limited(limit)
2869                }
2870                zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
2871            };
2872
2873            proto::SubscriptionUsage {
2874                model_requests_usage_amount: usage.model_requests as u32,
2875                model_requests_usage_limit: Some(proto::UsageLimit {
2876                    variant: Some(match model_requests_limit {
2877                        zed_llm_client::UsageLimit::Limited(limit) => {
2878                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2879                                limit: limit as u32,
2880                            })
2881                        }
2882                        zed_llm_client::UsageLimit::Unlimited => {
2883                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2884                        }
2885                    }),
2886                }),
2887                edit_predictions_usage_amount: usage.edit_predictions as u32,
2888                edit_predictions_usage_limit: Some(proto::UsageLimit {
2889                    variant: Some(match plan.edit_predictions_limit() {
2890                        zed_llm_client::UsageLimit::Limited(limit) => {
2891                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2892                                limit: limit as u32,
2893                            })
2894                        }
2895                        zed_llm_client::UsageLimit::Unlimited => {
2896                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2897                        }
2898                    }),
2899                }),
2900            }
2901        }),
2902    })
2903}
2904
2905async fn update_user_plan(session: &Session) -> Result<()> {
2906    let db = session.db().await;
2907
2908    let update_user_plan = make_update_user_plan_message(
2909        session.principal.user(),
2910        session.is_staff(),
2911        &db.0,
2912        session.app_state.llm_db.clone(),
2913    )
2914    .await?;
2915
2916    session
2917        .peer
2918        .send(session.connection_id, update_user_plan)
2919        .trace_err();
2920
2921    Ok(())
2922}
2923
2924async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2925    subscribe_user_to_channels(session.user_id(), &session).await?;
2926    Ok(())
2927}
2928
2929async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2930    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2931    let mut pool = session.connection_pool().await;
2932    for membership in &channels_for_user.channel_memberships {
2933        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2934    }
2935    session.peer.send(
2936        session.connection_id,
2937        build_update_user_channels(&channels_for_user),
2938    )?;
2939    session.peer.send(
2940        session.connection_id,
2941        build_channels_update(channels_for_user),
2942    )?;
2943    Ok(())
2944}
2945
2946/// Creates a new channel.
2947async fn create_channel(
2948    request: proto::CreateChannel,
2949    response: Response<proto::CreateChannel>,
2950    session: Session,
2951) -> Result<()> {
2952    let db = session.db().await;
2953
2954    let parent_id = request.parent_id.map(ChannelId::from_proto);
2955    let (channel, membership) = db
2956        .create_channel(&request.name, parent_id, session.user_id())
2957        .await?;
2958
2959    let root_id = channel.root_id();
2960    let channel = Channel::from_model(channel);
2961
2962    response.send(proto::CreateChannelResponse {
2963        channel: Some(channel.to_proto()),
2964        parent_id: request.parent_id,
2965    })?;
2966
2967    let mut connection_pool = session.connection_pool().await;
2968    if let Some(membership) = membership {
2969        connection_pool.subscribe_to_channel(
2970            membership.user_id,
2971            membership.channel_id,
2972            membership.role,
2973        );
2974        let update = proto::UpdateUserChannels {
2975            channel_memberships: vec![proto::ChannelMembership {
2976                channel_id: membership.channel_id.to_proto(),
2977                role: membership.role.into(),
2978            }],
2979            ..Default::default()
2980        };
2981        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2982            session.peer.send(connection_id, update.clone())?;
2983        }
2984    }
2985
2986    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2987        if !role.can_see_channel(channel.visibility) {
2988            continue;
2989        }
2990
2991        let update = proto::UpdateChannels {
2992            channels: vec![channel.to_proto()],
2993            ..Default::default()
2994        };
2995        session.peer.send(connection_id, update.clone())?;
2996    }
2997
2998    Ok(())
2999}
3000
3001/// Delete a channel
3002async fn delete_channel(
3003    request: proto::DeleteChannel,
3004    response: Response<proto::DeleteChannel>,
3005    session: Session,
3006) -> Result<()> {
3007    let db = session.db().await;
3008
3009    let channel_id = request.channel_id;
3010    let (root_channel, removed_channels) = db
3011        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3012        .await?;
3013    response.send(proto::Ack {})?;
3014
3015    // Notify members of removed channels
3016    let mut update = proto::UpdateChannels::default();
3017    update
3018        .delete_channels
3019        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3020
3021    let connection_pool = session.connection_pool().await;
3022    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3023        session.peer.send(connection_id, update.clone())?;
3024    }
3025
3026    Ok(())
3027}
3028
3029/// Invite someone to join a channel.
3030async fn invite_channel_member(
3031    request: proto::InviteChannelMember,
3032    response: Response<proto::InviteChannelMember>,
3033    session: Session,
3034) -> Result<()> {
3035    let db = session.db().await;
3036    let channel_id = ChannelId::from_proto(request.channel_id);
3037    let invitee_id = UserId::from_proto(request.user_id);
3038    let InviteMemberResult {
3039        channel,
3040        notifications,
3041    } = db
3042        .invite_channel_member(
3043            channel_id,
3044            invitee_id,
3045            session.user_id(),
3046            request.role().into(),
3047        )
3048        .await?;
3049
3050    let update = proto::UpdateChannels {
3051        channel_invitations: vec![channel.to_proto()],
3052        ..Default::default()
3053    };
3054
3055    let connection_pool = session.connection_pool().await;
3056    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3057        session.peer.send(connection_id, update.clone())?;
3058    }
3059
3060    send_notifications(&connection_pool, &session.peer, notifications);
3061
3062    response.send(proto::Ack {})?;
3063    Ok(())
3064}
3065
3066/// remove someone from a channel
3067async fn remove_channel_member(
3068    request: proto::RemoveChannelMember,
3069    response: Response<proto::RemoveChannelMember>,
3070    session: Session,
3071) -> Result<()> {
3072    let db = session.db().await;
3073    let channel_id = ChannelId::from_proto(request.channel_id);
3074    let member_id = UserId::from_proto(request.user_id);
3075
3076    let RemoveChannelMemberResult {
3077        membership_update,
3078        notification_id,
3079    } = db
3080        .remove_channel_member(channel_id, member_id, session.user_id())
3081        .await?;
3082
3083    let mut connection_pool = session.connection_pool().await;
3084    notify_membership_updated(
3085        &mut connection_pool,
3086        membership_update,
3087        member_id,
3088        &session.peer,
3089    );
3090    for connection_id in connection_pool.user_connection_ids(member_id) {
3091        if let Some(notification_id) = notification_id {
3092            session
3093                .peer
3094                .send(
3095                    connection_id,
3096                    proto::DeleteNotification {
3097                        notification_id: notification_id.to_proto(),
3098                    },
3099                )
3100                .trace_err();
3101        }
3102    }
3103
3104    response.send(proto::Ack {})?;
3105    Ok(())
3106}
3107
3108/// Toggle the channel between public and private.
3109/// Care is taken to maintain the invariant that public channels only descend from public channels,
3110/// (though members-only channels can appear at any point in the hierarchy).
3111async fn set_channel_visibility(
3112    request: proto::SetChannelVisibility,
3113    response: Response<proto::SetChannelVisibility>,
3114    session: Session,
3115) -> Result<()> {
3116    let db = session.db().await;
3117    let channel_id = ChannelId::from_proto(request.channel_id);
3118    let visibility = request.visibility().into();
3119
3120    let channel_model = db
3121        .set_channel_visibility(channel_id, visibility, session.user_id())
3122        .await?;
3123    let root_id = channel_model.root_id();
3124    let channel = Channel::from_model(channel_model);
3125
3126    let mut connection_pool = session.connection_pool().await;
3127    for (user_id, role) in connection_pool
3128        .channel_user_ids(root_id)
3129        .collect::<Vec<_>>()
3130        .into_iter()
3131    {
3132        let update = if role.can_see_channel(channel.visibility) {
3133            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3134            proto::UpdateChannels {
3135                channels: vec![channel.to_proto()],
3136                ..Default::default()
3137            }
3138        } else {
3139            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3140            proto::UpdateChannels {
3141                delete_channels: vec![channel.id.to_proto()],
3142                ..Default::default()
3143            }
3144        };
3145
3146        for connection_id in connection_pool.user_connection_ids(user_id) {
3147            session.peer.send(connection_id, update.clone())?;
3148        }
3149    }
3150
3151    response.send(proto::Ack {})?;
3152    Ok(())
3153}
3154
3155/// Alter the role for a user in the channel.
3156async fn set_channel_member_role(
3157    request: proto::SetChannelMemberRole,
3158    response: Response<proto::SetChannelMemberRole>,
3159    session: Session,
3160) -> Result<()> {
3161    let db = session.db().await;
3162    let channel_id = ChannelId::from_proto(request.channel_id);
3163    let member_id = UserId::from_proto(request.user_id);
3164    let result = db
3165        .set_channel_member_role(
3166            channel_id,
3167            session.user_id(),
3168            member_id,
3169            request.role().into(),
3170        )
3171        .await?;
3172
3173    match result {
3174        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3175            let mut connection_pool = session.connection_pool().await;
3176            notify_membership_updated(
3177                &mut connection_pool,
3178                membership_update,
3179                member_id,
3180                &session.peer,
3181            )
3182        }
3183        db::SetMemberRoleResult::InviteUpdated(channel) => {
3184            let update = proto::UpdateChannels {
3185                channel_invitations: vec![channel.to_proto()],
3186                ..Default::default()
3187            };
3188
3189            for connection_id in session
3190                .connection_pool()
3191                .await
3192                .user_connection_ids(member_id)
3193            {
3194                session.peer.send(connection_id, update.clone())?;
3195            }
3196        }
3197    }
3198
3199    response.send(proto::Ack {})?;
3200    Ok(())
3201}
3202
3203/// Change the name of a channel
3204async fn rename_channel(
3205    request: proto::RenameChannel,
3206    response: Response<proto::RenameChannel>,
3207    session: Session,
3208) -> Result<()> {
3209    let db = session.db().await;
3210    let channel_id = ChannelId::from_proto(request.channel_id);
3211    let channel_model = db
3212        .rename_channel(channel_id, session.user_id(), &request.name)
3213        .await?;
3214    let root_id = channel_model.root_id();
3215    let channel = Channel::from_model(channel_model);
3216
3217    response.send(proto::RenameChannelResponse {
3218        channel: Some(channel.to_proto()),
3219    })?;
3220
3221    let connection_pool = session.connection_pool().await;
3222    let update = proto::UpdateChannels {
3223        channels: vec![channel.to_proto()],
3224        ..Default::default()
3225    };
3226    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3227        if role.can_see_channel(channel.visibility) {
3228            session.peer.send(connection_id, update.clone())?;
3229        }
3230    }
3231
3232    Ok(())
3233}
3234
3235/// Move a channel to a new parent.
3236async fn move_channel(
3237    request: proto::MoveChannel,
3238    response: Response<proto::MoveChannel>,
3239    session: Session,
3240) -> Result<()> {
3241    let channel_id = ChannelId::from_proto(request.channel_id);
3242    let to = ChannelId::from_proto(request.to);
3243
3244    let (root_id, channels) = session
3245        .db()
3246        .await
3247        .move_channel(channel_id, to, session.user_id())
3248        .await?;
3249
3250    let connection_pool = session.connection_pool().await;
3251    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3252        let channels = channels
3253            .iter()
3254            .filter_map(|channel| {
3255                if role.can_see_channel(channel.visibility) {
3256                    Some(channel.to_proto())
3257                } else {
3258                    None
3259                }
3260            })
3261            .collect::<Vec<_>>();
3262        if channels.is_empty() {
3263            continue;
3264        }
3265
3266        let update = proto::UpdateChannels {
3267            channels,
3268            ..Default::default()
3269        };
3270
3271        session.peer.send(connection_id, update.clone())?;
3272    }
3273
3274    response.send(Ack {})?;
3275    Ok(())
3276}
3277
3278async fn reorder_channel(
3279    request: proto::ReorderChannel,
3280    response: Response<proto::ReorderChannel>,
3281    session: Session,
3282) -> Result<()> {
3283    let channel_id = ChannelId::from_proto(request.channel_id);
3284    let direction = request.direction();
3285
3286    let updated_channels = session
3287        .db()
3288        .await
3289        .reorder_channel(channel_id, direction, session.user_id())
3290        .await?;
3291
3292    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3293        let connection_pool = session.connection_pool().await;
3294        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3295            let channels = updated_channels
3296                .iter()
3297                .filter_map(|channel| {
3298                    if role.can_see_channel(channel.visibility) {
3299                        Some(channel.to_proto())
3300                    } else {
3301                        None
3302                    }
3303                })
3304                .collect::<Vec<_>>();
3305
3306            if channels.is_empty() {
3307                continue;
3308            }
3309
3310            let update = proto::UpdateChannels {
3311                channels,
3312                ..Default::default()
3313            };
3314
3315            session.peer.send(connection_id, update.clone())?;
3316        }
3317    }
3318
3319    response.send(Ack {})?;
3320    Ok(())
3321}
3322
3323/// Get the list of channel members
3324async fn get_channel_members(
3325    request: proto::GetChannelMembers,
3326    response: Response<proto::GetChannelMembers>,
3327    session: Session,
3328) -> Result<()> {
3329    let db = session.db().await;
3330    let channel_id = ChannelId::from_proto(request.channel_id);
3331    let limit = if request.limit == 0 {
3332        u16::MAX as u64
3333    } else {
3334        request.limit
3335    };
3336    let (members, users) = db
3337        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3338        .await?;
3339    response.send(proto::GetChannelMembersResponse { members, users })?;
3340    Ok(())
3341}
3342
3343/// Accept or decline a channel invitation.
3344async fn respond_to_channel_invite(
3345    request: proto::RespondToChannelInvite,
3346    response: Response<proto::RespondToChannelInvite>,
3347    session: Session,
3348) -> Result<()> {
3349    let db = session.db().await;
3350    let channel_id = ChannelId::from_proto(request.channel_id);
3351    let RespondToChannelInvite {
3352        membership_update,
3353        notifications,
3354    } = db
3355        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3356        .await?;
3357
3358    let mut connection_pool = session.connection_pool().await;
3359    if let Some(membership_update) = membership_update {
3360        notify_membership_updated(
3361            &mut connection_pool,
3362            membership_update,
3363            session.user_id(),
3364            &session.peer,
3365        );
3366    } else {
3367        let update = proto::UpdateChannels {
3368            remove_channel_invitations: vec![channel_id.to_proto()],
3369            ..Default::default()
3370        };
3371
3372        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3373            session.peer.send(connection_id, update.clone())?;
3374        }
3375    };
3376
3377    send_notifications(&connection_pool, &session.peer, notifications);
3378
3379    response.send(proto::Ack {})?;
3380
3381    Ok(())
3382}
3383
3384/// Join the channels' room
3385async fn join_channel(
3386    request: proto::JoinChannel,
3387    response: Response<proto::JoinChannel>,
3388    session: Session,
3389) -> Result<()> {
3390    let channel_id = ChannelId::from_proto(request.channel_id);
3391    join_channel_internal(channel_id, Box::new(response), session).await
3392}
3393
/// Abstraction over the response handles that can answer a channel join.
///
/// It lets `join_channel_internal` reply with a `JoinRoomResponse` whether the
/// request was a `JoinChannel` or a `JoinRoom`.
trait JoinChannelInternalResponse {
    /// Send `result` back to the requesting client, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully-qualified call to `Response`'s own `send`. It is spelled out so
        // it is clear this delegates rather than recursing into this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegates to `Response`'s own `send`, fully qualified to avoid
        // recursing into this trait method.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3407
/// Shared implementation for joining a channel's room.
///
/// It is called by `join_channel`. The `Response<proto::JoinRoom>` impl of
/// `JoinChannelInternalResponse` suggests a room-join handler presumably uses
/// it too.
///
/// Steps:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first.
/// 2. Join the channel in the database. Reply with the room, the channel id
///    and LiveKit connection info. Guests get a guest token with
///    `can_publish: false`; everyone else gets a room token.
/// 3. Apply any membership change to the connection pool and broadcast the
///    room update.
/// 4. Broadcast the channel update, then refresh the user's contacts.
///
/// NOTE(review): the client is answered before `joined_room.channel` is
/// checked. If the database returns no channel, the handler errors after a
/// successful reply has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // This inner scope bounds the lifetimes of the `db` and `connection_pool`
    // guards. Both are released before `channel_updated` below locks the
    // connection pool again.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release our database handle before leaving the stale room.
            // `leave_room_for_session` takes `&session` and presumably locks
            // the database itself. Re-acquire the handle afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when token creation
        // fails (`trace_err()` turns the error into `None`, presumably after
        // logging it).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3501
3502/// Start editing the channel notes
3503async fn join_channel_buffer(
3504    request: proto::JoinChannelBuffer,
3505    response: Response<proto::JoinChannelBuffer>,
3506    session: Session,
3507) -> Result<()> {
3508    let db = session.db().await;
3509    let channel_id = ChannelId::from_proto(request.channel_id);
3510
3511    let open_response = db
3512        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3513        .await?;
3514
3515    let collaborators = open_response.collaborators.clone();
3516    response.send(open_response)?;
3517
3518    let update = UpdateChannelBufferCollaborators {
3519        channel_id: channel_id.to_proto(),
3520        collaborators: collaborators.clone(),
3521    };
3522    channel_buffer_updated(
3523        session.connection_id,
3524        collaborators
3525            .iter()
3526            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3527        &update,
3528        &session.peer,
3529    );
3530
3531    Ok(())
3532}
3533
3534/// Edit the channel notes
3535async fn update_channel_buffer(
3536    request: proto::UpdateChannelBuffer,
3537    session: Session,
3538) -> Result<()> {
3539    let db = session.db().await;
3540    let channel_id = ChannelId::from_proto(request.channel_id);
3541
3542    let (collaborators, epoch, version) = db
3543        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3544        .await?;
3545
3546    channel_buffer_updated(
3547        session.connection_id,
3548        collaborators.clone(),
3549        &proto::UpdateChannelBuffer {
3550            channel_id: channel_id.to_proto(),
3551            operations: request.operations,
3552        },
3553        &session.peer,
3554    );
3555
3556    let pool = &*session.connection_pool().await;
3557
3558    let non_collaborators =
3559        pool.channel_connection_ids(channel_id)
3560            .filter_map(|(connection_id, _)| {
3561                if collaborators.contains(&connection_id) {
3562                    None
3563                } else {
3564                    Some(connection_id)
3565                }
3566            });
3567
3568    broadcast(None, non_collaborators, |peer_id| {
3569        session.peer.send(
3570            peer_id,
3571            proto::UpdateChannels {
3572                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3573                    channel_id: channel_id.to_proto(),
3574                    epoch: epoch as u64,
3575                    version: version.clone(),
3576                }],
3577                ..Default::default()
3578            },
3579        )
3580    });
3581
3582    Ok(())
3583}
3584
3585/// Rejoin the channel notes after a connection blip
3586async fn rejoin_channel_buffers(
3587    request: proto::RejoinChannelBuffers,
3588    response: Response<proto::RejoinChannelBuffers>,
3589    session: Session,
3590) -> Result<()> {
3591    let db = session.db().await;
3592    let buffers = db
3593        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3594        .await?;
3595
3596    for rejoined_buffer in &buffers {
3597        let collaborators_to_notify = rejoined_buffer
3598            .buffer
3599            .collaborators
3600            .iter()
3601            .filter_map(|c| Some(c.peer_id?.into()));
3602        channel_buffer_updated(
3603            session.connection_id,
3604            collaborators_to_notify,
3605            &proto::UpdateChannelBufferCollaborators {
3606                channel_id: rejoined_buffer.buffer.channel_id,
3607                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3608            },
3609            &session.peer,
3610        );
3611    }
3612
3613    response.send(proto::RejoinChannelBuffersResponse {
3614        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3615    })?;
3616
3617    Ok(())
3618}
3619
3620/// Stop editing the channel notes
3621async fn leave_channel_buffer(
3622    request: proto::LeaveChannelBuffer,
3623    response: Response<proto::LeaveChannelBuffer>,
3624    session: Session,
3625) -> Result<()> {
3626    let db = session.db().await;
3627    let channel_id = ChannelId::from_proto(request.channel_id);
3628
3629    let left_buffer = db
3630        .leave_channel_buffer(channel_id, session.connection_id)
3631        .await?;
3632
3633    response.send(Ack {})?;
3634
3635    channel_buffer_updated(
3636        session.connection_id,
3637        left_buffer.connections,
3638        &proto::UpdateChannelBufferCollaborators {
3639            channel_id: channel_id.to_proto(),
3640            collaborators: left_buffer.collaborators,
3641        },
3642        &session.peer,
3643    );
3644
3645    Ok(())
3646}
3647
3648fn channel_buffer_updated<T: EnvelopedMessage>(
3649    sender_id: ConnectionId,
3650    collaborators: impl IntoIterator<Item = ConnectionId>,
3651    message: &T,
3652    peer: &Peer,
3653) {
3654    broadcast(Some(sender_id), collaborators, |peer_id| {
3655        peer.send(peer_id, message.clone())
3656    });
3657}
3658
3659fn send_notifications(
3660    connection_pool: &ConnectionPool,
3661    peer: &Peer,
3662    notifications: db::NotificationBatch,
3663) {
3664    for (user_id, notification) in notifications {
3665        for connection_id in connection_pool.user_connection_ids(user_id) {
3666            if let Err(error) = peer.send(
3667                connection_id,
3668                proto::AddNotification {
3669                    notification: Some(notification.clone()),
3670                },
3671            ) {
3672                tracing::error!(
3673                    "failed to send notification to {:?} {}",
3674                    connection_id,
3675                    error
3676                );
3677            }
3678        }
3679    }
3680}
3681
3682/// Send a message to the channel
3683async fn send_channel_message(
3684    request: proto::SendChannelMessage,
3685    response: Response<proto::SendChannelMessage>,
3686    session: Session,
3687) -> Result<()> {
3688    // Validate the message body.
3689    let body = request.body.trim().to_string();
3690    if body.len() > MAX_MESSAGE_LEN {
3691        return Err(anyhow!("message is too long"))?;
3692    }
3693    if body.is_empty() {
3694        return Err(anyhow!("message can't be blank"))?;
3695    }
3696
3697    // TODO: adjust mentions if body is trimmed
3698
3699    let timestamp = OffsetDateTime::now_utc();
3700    let nonce = request.nonce.context("nonce can't be blank")?;
3701
3702    let channel_id = ChannelId::from_proto(request.channel_id);
3703    let CreatedChannelMessage {
3704        message_id,
3705        participant_connection_ids,
3706        notifications,
3707    } = session
3708        .db()
3709        .await
3710        .create_channel_message(
3711            channel_id,
3712            session.user_id(),
3713            &body,
3714            &request.mentions,
3715            timestamp,
3716            nonce.clone().into(),
3717            request.reply_to_message_id.map(MessageId::from_proto),
3718        )
3719        .await?;
3720
3721    let message = proto::ChannelMessage {
3722        sender_id: session.user_id().to_proto(),
3723        id: message_id.to_proto(),
3724        body,
3725        mentions: request.mentions,
3726        timestamp: timestamp.unix_timestamp() as u64,
3727        nonce: Some(nonce),
3728        reply_to_message_id: request.reply_to_message_id,
3729        edited_at: None,
3730    };
3731    broadcast(
3732        Some(session.connection_id),
3733        participant_connection_ids.clone(),
3734        |connection| {
3735            session.peer.send(
3736                connection,
3737                proto::ChannelMessageSent {
3738                    channel_id: channel_id.to_proto(),
3739                    message: Some(message.clone()),
3740                },
3741            )
3742        },
3743    );
3744    response.send(proto::SendChannelMessageResponse {
3745        message: Some(message),
3746    })?;
3747
3748    let pool = &*session.connection_pool().await;
3749    let non_participants =
3750        pool.channel_connection_ids(channel_id)
3751            .filter_map(|(connection_id, _)| {
3752                if participant_connection_ids.contains(&connection_id) {
3753                    None
3754                } else {
3755                    Some(connection_id)
3756                }
3757            });
3758    broadcast(None, non_participants, |peer_id| {
3759        session.peer.send(
3760            peer_id,
3761            proto::UpdateChannels {
3762                latest_channel_message_ids: vec![proto::ChannelMessageId {
3763                    channel_id: channel_id.to_proto(),
3764                    message_id: message_id.to_proto(),
3765                }],
3766                ..Default::default()
3767            },
3768        )
3769    });
3770    send_notifications(pool, &session.peer, notifications);
3771
3772    Ok(())
3773}
3774
3775/// Delete a channel message
3776async fn remove_channel_message(
3777    request: proto::RemoveChannelMessage,
3778    response: Response<proto::RemoveChannelMessage>,
3779    session: Session,
3780) -> Result<()> {
3781    let channel_id = ChannelId::from_proto(request.channel_id);
3782    let message_id = MessageId::from_proto(request.message_id);
3783    let (connection_ids, existing_notification_ids) = session
3784        .db()
3785        .await
3786        .remove_channel_message(channel_id, message_id, session.user_id())
3787        .await?;
3788
3789    broadcast(
3790        Some(session.connection_id),
3791        connection_ids,
3792        move |connection| {
3793            session.peer.send(connection, request.clone())?;
3794
3795            for notification_id in &existing_notification_ids {
3796                session.peer.send(
3797                    connection,
3798                    proto::DeleteNotification {
3799                        notification_id: (*notification_id).to_proto(),
3800                    },
3801                )?;
3802            }
3803
3804            Ok(())
3805        },
3806    );
3807    response.send(proto::Ack {})?;
3808    Ok(())
3809}
3810
3811async fn update_channel_message(
3812    request: proto::UpdateChannelMessage,
3813    response: Response<proto::UpdateChannelMessage>,
3814    session: Session,
3815) -> Result<()> {
3816    let channel_id = ChannelId::from_proto(request.channel_id);
3817    let message_id = MessageId::from_proto(request.message_id);
3818    let updated_at = OffsetDateTime::now_utc();
3819    let UpdatedChannelMessage {
3820        message_id,
3821        participant_connection_ids,
3822        notifications,
3823        reply_to_message_id,
3824        timestamp,
3825        deleted_mention_notification_ids,
3826        updated_mention_notifications,
3827    } = session
3828        .db()
3829        .await
3830        .update_channel_message(
3831            channel_id,
3832            message_id,
3833            session.user_id(),
3834            request.body.as_str(),
3835            &request.mentions,
3836            updated_at,
3837        )
3838        .await?;
3839
3840    let nonce = request.nonce.clone().context("nonce can't be blank")?;
3841
3842    let message = proto::ChannelMessage {
3843        sender_id: session.user_id().to_proto(),
3844        id: message_id.to_proto(),
3845        body: request.body.clone(),
3846        mentions: request.mentions.clone(),
3847        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3848        nonce: Some(nonce),
3849        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3850        edited_at: Some(updated_at.unix_timestamp() as u64),
3851    };
3852
3853    response.send(proto::Ack {})?;
3854
3855    let pool = &*session.connection_pool().await;
3856    broadcast(
3857        Some(session.connection_id),
3858        participant_connection_ids,
3859        |connection| {
3860            session.peer.send(
3861                connection,
3862                proto::ChannelMessageUpdate {
3863                    channel_id: channel_id.to_proto(),
3864                    message: Some(message.clone()),
3865                },
3866            )?;
3867
3868            for notification_id in &deleted_mention_notification_ids {
3869                session.peer.send(
3870                    connection,
3871                    proto::DeleteNotification {
3872                        notification_id: (*notification_id).to_proto(),
3873                    },
3874                )?;
3875            }
3876
3877            for notification in &updated_mention_notifications {
3878                session.peer.send(
3879                    connection,
3880                    proto::UpdateNotification {
3881                        notification: Some(notification.clone()),
3882                    },
3883                )?;
3884            }
3885
3886            Ok(())
3887        },
3888    );
3889
3890    send_notifications(pool, &session.peer, notifications);
3891
3892    Ok(())
3893}
3894
3895/// Mark a channel message as read
3896async fn acknowledge_channel_message(
3897    request: proto::AckChannelMessage,
3898    session: Session,
3899) -> Result<()> {
3900    let channel_id = ChannelId::from_proto(request.channel_id);
3901    let message_id = MessageId::from_proto(request.message_id);
3902    let notifications = session
3903        .db()
3904        .await
3905        .observe_channel_message(channel_id, session.user_id(), message_id)
3906        .await?;
3907    send_notifications(
3908        &*session.connection_pool().await,
3909        &session.peer,
3910        notifications,
3911    );
3912    Ok(())
3913}
3914
3915/// Mark a buffer version as synced
3916async fn acknowledge_buffer_version(
3917    request: proto::AckBufferOperation,
3918    session: Session,
3919) -> Result<()> {
3920    let buffer_id = BufferId::from_proto(request.buffer_id);
3921    session
3922        .db()
3923        .await
3924        .observe_buffer_version(
3925            buffer_id,
3926            session.user_id(),
3927            request.epoch as i32,
3928            &request.version,
3929        )
3930        .await?;
3931    Ok(())
3932}
3933
3934/// Get a Supermaven API key for the user
3935async fn get_supermaven_api_key(
3936    _request: proto::GetSupermavenApiKey,
3937    response: Response<proto::GetSupermavenApiKey>,
3938    session: Session,
3939) -> Result<()> {
3940    let user_id: String = session.user_id().to_string();
3941    if !session.is_staff() {
3942        return Err(anyhow!("supermaven not enabled for this account"))?;
3943    }
3944
3945    let email = session.email().context("user must have an email")?;
3946
3947    let supermaven_admin_api = session
3948        .supermaven_client
3949        .as_ref()
3950        .context("supermaven not configured")?;
3951
3952    let result = supermaven_admin_api
3953        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3954        .await?;
3955
3956    response.send(proto::GetSupermavenApiKeyResponse {
3957        api_key: result.api_key,
3958    })?;
3959
3960    Ok(())
3961}
3962
3963/// Start receiving chat updates for a channel
3964async fn join_channel_chat(
3965    request: proto::JoinChannelChat,
3966    response: Response<proto::JoinChannelChat>,
3967    session: Session,
3968) -> Result<()> {
3969    let channel_id = ChannelId::from_proto(request.channel_id);
3970
3971    let db = session.db().await;
3972    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3973        .await?;
3974    let messages = db
3975        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3976        .await?;
3977    response.send(proto::JoinChannelChatResponse {
3978        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3979        messages,
3980    })?;
3981    Ok(())
3982}
3983
3984/// Stop receiving chat updates for a channel
3985async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3986    let channel_id = ChannelId::from_proto(request.channel_id);
3987    session
3988        .db()
3989        .await
3990        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3991        .await?;
3992    Ok(())
3993}
3994
3995/// Retrieve the chat history for a channel
3996async fn get_channel_messages(
3997    request: proto::GetChannelMessages,
3998    response: Response<proto::GetChannelMessages>,
3999    session: Session,
4000) -> Result<()> {
4001    let channel_id = ChannelId::from_proto(request.channel_id);
4002    let messages = session
4003        .db()
4004        .await
4005        .get_channel_messages(
4006            channel_id,
4007            session.user_id(),
4008            MESSAGE_COUNT_PER_PAGE,
4009            Some(MessageId::from_proto(request.before_message_id)),
4010        )
4011        .await?;
4012    response.send(proto::GetChannelMessagesResponse {
4013        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4014        messages,
4015    })?;
4016    Ok(())
4017}
4018
4019/// Retrieve specific chat messages
4020async fn get_channel_messages_by_id(
4021    request: proto::GetChannelMessagesById,
4022    response: Response<proto::GetChannelMessagesById>,
4023    session: Session,
4024) -> Result<()> {
4025    let message_ids = request
4026        .message_ids
4027        .iter()
4028        .map(|id| MessageId::from_proto(*id))
4029        .collect::<Vec<_>>();
4030    let messages = session
4031        .db()
4032        .await
4033        .get_channel_messages_by_id(session.user_id(), &message_ids)
4034        .await?;
4035    response.send(proto::GetChannelMessagesResponse {
4036        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4037        messages,
4038    })?;
4039    Ok(())
4040}
4041
4042/// Retrieve the current users notifications
4043async fn get_notifications(
4044    request: proto::GetNotifications,
4045    response: Response<proto::GetNotifications>,
4046    session: Session,
4047) -> Result<()> {
4048    let notifications = session
4049        .db()
4050        .await
4051        .get_notifications(
4052            session.user_id(),
4053            NOTIFICATION_COUNT_PER_PAGE,
4054            request.before_id.map(db::NotificationId::from_proto),
4055        )
4056        .await?;
4057    response.send(proto::GetNotificationsResponse {
4058        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4059        notifications,
4060    })?;
4061    Ok(())
4062}
4063
4064/// Mark notifications as read
4065async fn mark_notification_as_read(
4066    request: proto::MarkNotificationRead,
4067    response: Response<proto::MarkNotificationRead>,
4068    session: Session,
4069) -> Result<()> {
4070    let database = &session.db().await;
4071    let notifications = database
4072        .mark_notification_as_read_by_id(
4073            session.user_id(),
4074            NotificationId::from_proto(request.notification_id),
4075        )
4076        .await?;
4077    send_notifications(
4078        &*session.connection_pool().await,
4079        &session.peer,
4080        notifications,
4081    );
4082    response.send(proto::Ack {})?;
4083    Ok(())
4084}
4085
4086/// Get the current users information
4087async fn get_private_user_info(
4088    _request: proto::GetPrivateUserInfo,
4089    response: Response<proto::GetPrivateUserInfo>,
4090    session: Session,
4091) -> Result<()> {
4092    let db = session.db().await;
4093
4094    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4095    let user = db
4096        .get_user_by_id(session.user_id())
4097        .await?
4098        .context("user not found")?;
4099    let flags = db.get_user_flags(session.user_id()).await?;
4100
4101    response.send(proto::GetPrivateUserInfoResponse {
4102        metrics_id,
4103        staff: user.admin,
4104        flags,
4105        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4106    })?;
4107    Ok(())
4108}
4109
4110/// Accept the terms of service (tos) on behalf of the current user
4111async fn accept_terms_of_service(
4112    _request: proto::AcceptTermsOfService,
4113    response: Response<proto::AcceptTermsOfService>,
4114    session: Session,
4115) -> Result<()> {
4116    let db = session.db().await;
4117
4118    let accepted_tos_at = Utc::now();
4119    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4120        .await?;
4121
4122    response.send(proto::AcceptTermsOfServiceResponse {
4123        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4124    })?;
4125    Ok(())
4126}
4127
/// Issue an LLM API token for the current user.
///
/// Requires the user to have accepted the terms of service. It also makes sure
/// the billing state exists:
/// - a billing customer is looked up in the database, or found/created via
///   Stripe by the user's email;
/// - an active billing subscription is looked up, or the user is subscribed to
///   Zed Free in Stripe and the subscription is recorded.
///
/// Finally it builds the token with `LlmTokenClaims::create` from the user,
/// staff status, billing state, feature flags, and the session's system id.
///
/// NOTE(review): the customer/subscription creation is a read-then-create
/// sequence. Concurrent requests from different sessions of the same user
/// might both take the create path; confirm whether this is deduplicated
/// downstream.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .with_context(|| format!("user {user_id} not found"))?;

    // Tokens are only issued once the terms of service have been accepted.
    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    let stripe_client = session
        .app_state
        .stripe_client
        .as_ref()
        .context("failed to retrieve Stripe client")?;

    let stripe_billing = session
        .app_state
        .stripe_billing
        .as_ref()
        .context("failed to retrieve Stripe billing object")?;

    // Prefer the stored billing customer. Otherwise resolve the Stripe customer
    // by email and create the matching database record.
    let billing_customer = if let Some(billing_customer) =
        db.get_billing_customer_by_user_id(user.id).await?
    {
        billing_customer
    } else {
        let customer_id = stripe_billing
            .find_or_create_customer_by_email(user.email_address.as_deref())
            .await?;

        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
            .await?
            .context("billing customer not found")?
    };

    // Without an active subscription, subscribe the customer to Zed Free in
    // Stripe and persist that subscription.
    let billing_subscription =
        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
            billing_subscription
        } else {
            let stripe_customer_id =
                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());

            let stripe_subscription = stripe_billing
                .subscribe_to_zed_free(stripe_customer_id)
                .await?;

            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
                billing_customer_id: billing_customer.id,
                kind: Some(SubscriptionKind::ZedFree),
                stripe_subscription_id: stripe_subscription.id.to_string(),
                stripe_subscription_status: stripe_subscription.status.into(),
                stripe_cancellation_reason: None,
                stripe_current_period_start: Some(stripe_subscription.current_period_start),
                stripe_current_period_end: Some(stripe_subscription.current_period_end),
            })
            .await?
        };

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_customer,
        billing_preferences,
        &flags,
        billing_subscription,
        session.system_id.clone(),
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
4211
4212fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4213    let message = match message {
4214        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4215        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4216        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4217        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4218        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4219            code: frame.code.into(),
4220            reason: frame.reason.as_str().to_owned().into(),
4221        })),
4222        // We should never receive a frame while reading the message, according
4223        // to the `tungstenite` maintainers:
4224        //
4225        // > It cannot occur when you read messages from the WebSocket, but it
4226        // > can be used when you want to send the raw frames (e.g. you want to
4227        // > send the frames to the WebSocket without composing the full message first).
4228        // >
4229        // > — https://github.com/snapview/tungstenite-rs/issues/268
4230        TungsteniteMessage::Frame(_) => {
4231            bail!("received an unexpected frame while reading the message")
4232        }
4233    };
4234
4235    Ok(message)
4236}
4237
4238fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4239    match message {
4240        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4241        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4242        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4243        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4244        AxumMessage::Close(frame) => {
4245            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4246                code: frame.code.into(),
4247                reason: frame.reason.as_ref().into(),
4248            }))
4249        }
4250    }
4251}
4252
4253fn notify_membership_updated(
4254    connection_pool: &mut ConnectionPool,
4255    result: MembershipUpdated,
4256    user_id: UserId,
4257    peer: &Peer,
4258) {
4259    for membership in &result.new_channels.channel_memberships {
4260        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4261    }
4262    for channel_id in &result.removed_channels {
4263        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4264    }
4265
4266    let user_channels_update = proto::UpdateUserChannels {
4267        channel_memberships: result
4268            .new_channels
4269            .channel_memberships
4270            .iter()
4271            .map(|cm| proto::ChannelMembership {
4272                channel_id: cm.channel_id.to_proto(),
4273                role: cm.role.into(),
4274            })
4275            .collect(),
4276        ..Default::default()
4277    };
4278
4279    let mut update = build_channels_update(result.new_channels);
4280    update.delete_channels = result
4281        .removed_channels
4282        .into_iter()
4283        .map(|id| id.to_proto())
4284        .collect();
4285    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4286
4287    for connection_id in connection_pool.user_connection_ids(user_id) {
4288        peer.send(connection_id, user_channels_update.clone())
4289            .trace_err();
4290        peer.send(connection_id, update.clone()).trace_err();
4291    }
4292}
4293
4294fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4295    proto::UpdateUserChannels {
4296        channel_memberships: channels
4297            .channel_memberships
4298            .iter()
4299            .map(|m| proto::ChannelMembership {
4300                channel_id: m.channel_id.to_proto(),
4301                role: m.role.into(),
4302            })
4303            .collect(),
4304        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4305        observed_channel_message_id: channels.observed_channel_messages.clone(),
4306    }
4307}
4308
4309fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4310    let mut update = proto::UpdateChannels::default();
4311
4312    for channel in channels.channels {
4313        update.channels.push(channel.to_proto());
4314    }
4315
4316    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4317    update.latest_channel_message_ids = channels.latest_channel_messages;
4318
4319    for (channel_id, participants) in channels.channel_participants {
4320        update
4321            .channel_participants
4322            .push(proto::ChannelParticipants {
4323                channel_id: channel_id.to_proto(),
4324                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4325            });
4326    }
4327
4328    for channel in channels.invited_channels {
4329        update.channel_invitations.push(channel.to_proto());
4330    }
4331
4332    update
4333}
4334
4335fn build_initial_contacts_update(
4336    contacts: Vec<db::Contact>,
4337    pool: &ConnectionPool,
4338) -> proto::UpdateContacts {
4339    let mut update = proto::UpdateContacts::default();
4340
4341    for contact in contacts {
4342        match contact {
4343            db::Contact::Accepted { user_id, busy } => {
4344                update.contacts.push(contact_for_user(user_id, busy, pool));
4345            }
4346            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4347            db::Contact::Incoming { user_id } => {
4348                update
4349                    .incoming_requests
4350                    .push(proto::IncomingContactRequest {
4351                        requester_id: user_id.to_proto(),
4352                    })
4353            }
4354        }
4355    }
4356
4357    update
4358}
4359
4360fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4361    proto::Contact {
4362        user_id: user_id.to_proto(),
4363        online: pool.is_user_online(user_id),
4364        busy,
4365    }
4366}
4367
4368fn room_updated(room: &proto::Room, peer: &Peer) {
4369    broadcast(
4370        None,
4371        room.participants
4372            .iter()
4373            .filter_map(|participant| Some(participant.peer_id?.into())),
4374        |peer_id| {
4375            peer.send(
4376                peer_id,
4377                proto::RoomUpdated {
4378                    room: Some(room.clone()),
4379                },
4380            )
4381        },
4382    );
4383}
4384
4385fn channel_updated(
4386    channel: &db::channel::Model,
4387    room: &proto::Room,
4388    peer: &Peer,
4389    pool: &ConnectionPool,
4390) {
4391    let participants = room
4392        .participants
4393        .iter()
4394        .map(|p| p.user_id)
4395        .collect::<Vec<_>>();
4396
4397    broadcast(
4398        None,
4399        pool.channel_connection_ids(channel.root_id())
4400            .filter_map(|(channel_id, role)| {
4401                role.can_see_channel(channel.visibility)
4402                    .then_some(channel_id)
4403            }),
4404        |peer_id| {
4405            peer.send(
4406                peer_id,
4407                proto::UpdateChannels {
4408                    channel_participants: vec![proto::ChannelParticipants {
4409                        channel_id: channel.id.to_proto(),
4410                        participant_user_ids: participants.clone(),
4411                    }],
4412                    ..Default::default()
4413                },
4414            )
4415        },
4416    );
4417}
4418
4419async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4420    let db = session.db().await;
4421
4422    let contacts = db.get_contacts(user_id).await?;
4423    let busy = db.is_user_busy(user_id).await?;
4424
4425    let pool = session.connection_pool().await;
4426    let updated_contact = contact_for_user(user_id, busy, &pool);
4427    for contact in contacts {
4428        if let db::Contact::Accepted {
4429            user_id: contact_user_id,
4430            ..
4431        } = contact
4432        {
4433            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4434                session
4435                    .peer
4436                    .send(
4437                        contact_conn_id,
4438                        proto::UpdateContacts {
4439                            contacts: vec![updated_contact.clone()],
4440                            remove_contacts: Default::default(),
4441                            incoming_requests: Default::default(),
4442                            remove_incoming_requests: Default::default(),
4443                            outgoing_requests: Default::default(),
4444                            remove_outgoing_requests: Default::default(),
4445                        },
4446                    )
4447                    .trace_err();
4448            }
4449        }
4450    }
4451    Ok(())
4452}
4453
/// Makes `connection_id` leave its current room, if any, and sends out the
/// resulting notifications.
///
/// If `leave_room` returns a left room, this:
/// - calls `project_left` for each left project;
/// - broadcasts the returned room with `room_updated`;
/// - if a channel was returned, broadcasts its participants with
///   `channel_updated`;
/// - sends `CallCanceled` to every connection of each user in
///   `canceled_calls_to_user_ids`;
/// - refreshes the contact entry of this session's user and of every user
///   whose call was canceled;
/// - if a LiveKit client is configured, removes this user from the LiveKit
///   room, and deletes that room when `left_room.deleted` is set.
///
/// Returns `Ok(())` immediately if `leave_room` returns `None`.
///
/// # Errors
///
/// Propagates errors from `leave_room` and `update_user_contacts`.
/// LiveKit failures are only logged.
/// NOTE(review): an error from `update_user_contacts` returns before the LiveKit
/// cleanup at the end runs. Confirm that skipping it is acceptable.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned inside the `if let` below. The `else` branch returns early, so
    // all of these are initialized once control reaches the code after it.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // The `session.db()` handle is a temporary of the scrutinee, so it stays
    // alive for the whole `then` branch.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out. The room broadcast below therefore
        // carries the `Default` value in `livekit_room`, not the original.
        // NOTE(review): confirm clients do not rely on that field in this update.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        // The pool handle here is a temporary that lives only for this call.
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so `pool` is dropped before `update_user_contacts` runs, because
    // that function calls `session.connection_pool()` again.
    // NOTE(review): presumably a lock guard, so holding it across that call
    // could deadlock.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged via `trace_err`.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4527
4528async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4529    let left_channel_buffers = session
4530        .db()
4531        .await
4532        .leave_channel_buffers(session.connection_id)
4533        .await?;
4534
4535    for left_buffer in left_channel_buffers {
4536        channel_buffer_updated(
4537            session.connection_id,
4538            left_buffer.connections,
4539            &proto::UpdateChannelBufferCollaborators {
4540                channel_id: left_buffer.channel_id.to_proto(),
4541                collaborators: left_buffer.collaborators,
4542            },
4543            &session.peer,
4544        );
4545    }
4546
4547    Ok(())
4548}
4549
4550fn project_left(project: &db::LeftProject, session: &Session) {
4551    for connection_id in &project.connection_ids {
4552        if project.should_unshare {
4553            session
4554                .peer
4555                .send(
4556                    *connection_id,
4557                    proto::UnshareProject {
4558                        project_id: project.id.to_proto(),
4559                    },
4560                )
4561                .trace_err();
4562        } else {
4563            session
4564                .peer
4565                .send(
4566                    *connection_id,
4567                    proto::RemoveProjectCollaborator {
4568                        project_id: project.id.to_proto(),
4569                        peer_id: Some(session.connection_id.into()),
4570                    },
4571                )
4572                .trace_err();
4573        }
4574    }
4575}
4576
/// Extension trait for turning a fallible result into an `Option`, logging the
/// error instead of propagating it.
pub trait ResultExt {
    /// The success type carried by the result.
    type Ok;

    /// Returns `Some(value)` on success. On failure, the error is logged
    /// rather than propagated, and `None` is returned.
    fn trace_err(self) -> Option<Self::Ok>;
}
4582
4583impl<T, E> ResultExt for Result<T, E>
4584where
4585    E: std::fmt::Debug,
4586{
4587    type Ok = T;
4588
4589    #[track_caller]
4590    fn trace_err(self) -> Option<T> {
4591        match self {
4592            Ok(value) => Some(value),
4593            Err(error) => {
4594                tracing::error!("{:?}", error);
4595                None
4596            }
4597        }
4598    }
4599}