rpc.rs

   1mod connection_pool;
   2
   3use crate::api::billing::find_or_create_billing_customer;
   4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   5use crate::db::billing_subscription::SubscriptionKind;
   6use crate::llm::db::LlmDatabase;
   7use crate::llm::{
   8    AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
   9    MIN_ACCOUNT_AGE_FOR_LLM_USE,
  10};
  11use crate::stripe_client::StripeCustomerId;
  12use crate::{
  13    AppState, Error, Result, auth,
  14    db::{
  15        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
  16        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  17        NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
  18        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  19    },
  20    executor::Executor,
  21};
  22use anyhow::{Context as _, anyhow, bail};
  23use async_tungstenite::tungstenite::{
  24    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  25};
  26use axum::headers::UserAgent;
  27use axum::{
  28    Extension, Router, TypedHeader,
  29    body::Body,
  30    extract::{
  31        ConnectInfo, WebSocketUpgrade,
  32        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  33    },
  34    headers::{Header, HeaderName},
  35    http::StatusCode,
  36    middleware,
  37    response::IntoResponse,
  38    routing::get,
  39};
  40use chrono::Utc;
  41use collections::{HashMap, HashSet};
  42pub use connection_pool::{ConnectionPool, ZedVersion};
  43use core::fmt::{self, Debug, Formatter};
  44use futures::TryFutureExt as _;
  45use reqwest_client::ReqwestClient;
  46use rpc::proto::{MultiLspQuery, split_repository_update};
  47use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  48use tracing::Span;
  49
  50use futures::{
  51    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  52    stream::FuturesUnordered,
  53};
  54use prometheus::{IntGauge, register_int_gauge};
  55use rpc::{
  56    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  57    proto::{
  58        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  59        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  60    },
  61};
  62use semantic_version::SemanticVersion;
  63use serde::{Serialize, Serializer};
  64use std::{
  65    any::TypeId,
  66    future::Future,
  67    marker::PhantomData,
  68    mem,
  69    net::SocketAddr,
  70    ops::{Deref, DerefMut},
  71    rc::Rc,
  72    sync::{
  73        Arc, OnceLock,
  74        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  75    },
  76    time::{Duration, Instant},
  77};
  78use time::OffsetDateTime;
  79use tokio::sync::{Semaphore, watch};
  80use tower::ServiceBuilder;
  81use tracing::{
  82    Instrument,
  83    field::{self},
  84    info_span, instrument,
  85};
  86
  87pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  88
   89// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean up old resources.
  90pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  91
  92const MESSAGE_COUNT_PER_PAGE: usize = 100;
  93const MAX_MESSAGE_LEN: usize = 1024;
  94const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  95const MAX_CONCURRENT_CONNECTIONS: usize = 512;
  96
  97static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
  98
  99const TOTAL_DURATION_MS: &str = "total_duration_ms";
 100const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
 101const QUEUE_DURATION_MS: &str = "queue_duration_ms";
 102const HOST_WAITING_MS: &str = "host_waiting_ms";
 103
 104type MessageHandler =
 105    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
 106
 107pub struct ConnectionGuard;
 108
 109impl ConnectionGuard {
 110    pub fn try_acquire() -> Result<Self, ()> {
 111        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
 112        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 113            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 114            tracing::error!(
 115                "too many concurrent connections: {}",
 116                current_connections + 1
 117            );
 118            return Err(());
 119        }
 120        Ok(ConnectionGuard)
 121    }
 122}
 123
 124impl Drop for ConnectionGuard {
 125    fn drop(&mut self) {
 126        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 127    }
 128}
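
// A minimal usage sketch: acquire a slot before doing per-connection work and let `Drop` release
// it. The `do_connection_work` call is hypothetical; `handle_websocket_request` below performs
// the real acquisition before upgrading the WebSocket.
//
//     fn serve_one_connection() -> Result<(), ()> {
//         let _guard = ConnectionGuard::try_acquire()?; // increments CONCURRENT_CONNECTIONS
//         do_connection_work();
//         Ok(())
//     } // `_guard` dropped here, decrementing CONCURRENT_CONNECTIONS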
 129
 130struct Response<R> {
 131    peer: Arc<Peer>,
 132    receipt: Receipt<R>,
 133    responded: Arc<AtomicBool>,
 134}
 135
 136impl<R: RequestMessage> Response<R> {
 137    fn send(self, payload: R::Response) -> Result<()> {
 138        self.responded.store(true, SeqCst);
 139        self.peer.respond(self.receipt, payload)?;
 140        Ok(())
 141    }
 142}
 143
 144#[derive(Clone, Debug)]
 145pub enum Principal {
 146    User(User),
 147    Impersonated { user: User, admin: User },
 148}
 149
 150impl Principal {
 151    fn user(&self) -> &User {
 152        match self {
 153            Principal::User(user) => user,
 154            Principal::Impersonated { user, .. } => user,
 155        }
 156    }
 157
 158    fn update_span(&self, span: &tracing::Span) {
 159        match &self {
 160            Principal::User(user) => {
 161                span.record("user_id", user.id.0);
 162                span.record("login", &user.github_login);
 163            }
 164            Principal::Impersonated { user, admin } => {
 165                span.record("user_id", user.id.0);
 166                span.record("login", &user.github_login);
 167                span.record("impersonator", &admin.github_login);
 168            }
 169        }
 170    }
 171}
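
// `update_span` only records into fields that were declared when the span was created, as
// `handle_connection` does below. A condensed sketch:
//
//     let span = tracing::info_span!(
//         "handle connection",
//         user_id = tracing::field::Empty,
//         login = tracing::field::Empty,
//         impersonator = tracing::field::Empty,
//     );
//     principal.update_span(&span);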
 172
 173#[derive(Clone)]
 174struct MessageContext {
 175    session: Session,
 176    span: tracing::Span,
 177}
 178
 179impl Deref for MessageContext {
 180    type Target = Session;
 181
 182    fn deref(&self) -> &Self::Target {
 183        &self.session
 184    }
 185}
 186
 187impl MessageContext {
 188    pub fn forward_request<T: RequestMessage>(
 189        &self,
 190        receiver_id: ConnectionId,
 191        request: T,
 192    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 193        let request_start_time = Instant::now();
 194        let span = self.span.clone();
 195        tracing::info!("start forwarding request");
 196        self.peer
 197            .forward_request(self.connection_id, receiver_id, request)
 198            .inspect(move |_| {
 199                span.record(
 200                    HOST_WAITING_MS,
 201                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 202                );
 203            })
 204            .inspect_err(|_| tracing::error!("error forwarding request"))
 205            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 206    }
 207}
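
// Within a handler, the context is typically used to relay a request to the project host and
// send the host's reply back to the guest. A sketch, assuming `host_connection_id` has already
// been looked up for the project:
//
//     async fn relay_to_host<T: RequestMessage>(
//         request: T,
//         response: Response<T>,
//         ctx: MessageContext,
//         host_connection_id: ConnectionId,
//     ) -> Result<()> {
//         let payload = ctx.forward_request(host_connection_id, request).await?;
//         response.send(payload)?;
//         Ok(())
//     }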
 208
 209#[derive(Clone)]
 210struct Session {
 211    principal: Principal,
 212    connection_id: ConnectionId,
 213    db: Arc<tokio::sync::Mutex<DbHandle>>,
 214    peer: Arc<Peer>,
 215    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 216    app_state: Arc<AppState>,
 217    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 218    /// The GeoIP country code for the user.
 219    #[allow(unused)]
 220    geoip_country_code: Option<String>,
 221    system_id: Option<String>,
 222    _executor: Executor,
 223}
 224
 225impl Session {
 226    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 227        #[cfg(test)]
 228        tokio::task::yield_now().await;
 229        let guard = self.db.lock().await;
 230        #[cfg(test)]
 231        tokio::task::yield_now().await;
 232        guard
 233    }
 234
 235    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 236        #[cfg(test)]
 237        tokio::task::yield_now().await;
 238        let guard = self.connection_pool.lock();
 239        ConnectionPoolGuard {
 240            guard,
 241            _not_send: PhantomData,
 242        }
 243    }
 244
 245    fn is_staff(&self) -> bool {
 246        match &self.principal {
 247            Principal::User(user) => user.admin,
 248            Principal::Impersonated { .. } => true,
 249        }
 250    }
 251
 252    fn user_id(&self) -> UserId {
 253        match &self.principal {
 254            Principal::User(user) => user.id,
 255            Principal::Impersonated { user, .. } => user.id,
 256        }
 257    }
 258
 259    pub fn email(&self) -> Option<String> {
 260        match &self.principal {
 261            Principal::User(user) => user.email_address.clone(),
 262            Principal::Impersonated { user, .. } => user.email_address.clone(),
 263        }
 264    }
 265}
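
// Database access goes through `Session::db`, which serializes this connection's queries behind
// a tokio mutex. A sketch of the typical pattern inside a handler:
//
//     async fn lookup_self(session: &Session) -> Result<Option<User>> {
//         let db = session.db().await; // MutexGuard<'_, DbHandle>, derefs to Database
//         db.get_user_by_id(session.user_id()).await
//     }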
 266
 267impl Debug for Session {
 268    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 269        let mut result = f.debug_struct("Session");
 270        match &self.principal {
 271            Principal::User(user) => {
 272                result.field("user", &user.github_login);
 273            }
 274            Principal::Impersonated { user, admin } => {
 275                result.field("user", &user.github_login);
 276                result.field("impersonator", &admin.github_login);
 277            }
 278        }
 279        result.field("connection_id", &self.connection_id).finish()
 280    }
 281}
 282
 283struct DbHandle(Arc<Database>);
 284
 285impl Deref for DbHandle {
 286    type Target = Database;
 287
 288    fn deref(&self) -> &Self::Target {
 289        self.0.as_ref()
 290    }
 291}
 292
 293pub struct Server {
 294    id: parking_lot::Mutex<ServerId>,
 295    peer: Arc<Peer>,
 296    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 297    app_state: Arc<AppState>,
 298    handlers: HashMap<TypeId, MessageHandler>,
 299    teardown: watch::Sender<bool>,
 300}
 301
 302pub(crate) struct ConnectionPoolGuard<'a> {
 303    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 304    _not_send: PhantomData<Rc<()>>,
 305}
 306
 307#[derive(Serialize)]
 308pub struct ServerSnapshot<'a> {
 309    peer: &'a Peer,
 310    #[serde(serialize_with = "serialize_deref")]
 311    connection_pool: ConnectionPoolGuard<'a>,
 312}
 313
 314pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 315where
 316    S: Serializer,
 317    T: Deref<Target = U>,
 318    U: Serialize,
 319{
 320    Serialize::serialize(value.deref(), serializer)
 321}
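
// `serialize_deref` lets `ServerSnapshot` serialize the pool through the guard's `Deref` impl.
// It works for any wrapper whose target is `Serialize`; a hypothetical example:
//
//     #[derive(Serialize)]
//     struct QueueSnapshot<'a> {
//         #[serde(serialize_with = "serialize_deref")]
//         entries: std::sync::MutexGuard<'a, Vec<String>>,
//     }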
 322
 323impl Server {
 324    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 325        let mut server = Self {
 326            id: parking_lot::Mutex::new(id),
 327            peer: Peer::new(id.0 as u32),
 328            app_state: app_state.clone(),
 329            connection_pool: Default::default(),
 330            handlers: Default::default(),
 331            teardown: watch::channel(false).0,
 332        };
 333
 334        server
 335            .add_request_handler(ping)
 336            .add_request_handler(create_room)
 337            .add_request_handler(join_room)
 338            .add_request_handler(rejoin_room)
 339            .add_request_handler(leave_room)
 340            .add_request_handler(set_room_participant_role)
 341            .add_request_handler(call)
 342            .add_request_handler(cancel_call)
 343            .add_message_handler(decline_call)
 344            .add_request_handler(update_participant_location)
 345            .add_request_handler(share_project)
 346            .add_message_handler(unshare_project)
 347            .add_request_handler(join_project)
 348            .add_message_handler(leave_project)
 349            .add_request_handler(update_project)
 350            .add_request_handler(update_worktree)
 351            .add_request_handler(update_repository)
 352            .add_request_handler(remove_repository)
 353            .add_message_handler(start_language_server)
 354            .add_message_handler(update_language_server)
 355            .add_message_handler(update_diagnostic_summary)
 356            .add_message_handler(update_worktree_settings)
 357            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 358            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 359            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 360            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 361            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 362            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 363            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 364            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 365            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 366            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 367            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 368            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 369            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 370            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 371            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 372            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 373            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 374            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 375            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 376            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 377            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 378            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 379            .add_request_handler(
 380                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 381            )
 382            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 383            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 384            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 385            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 386            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
 387            .add_request_handler(
 388                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 389            )
 390            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 391            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 392            .add_request_handler(
 393                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 394            )
 395            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 396            .add_request_handler(
 397                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 398            )
 399            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 400            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 401            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 402            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 403            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 404            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 405            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 406            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 407            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 408            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 409            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 410            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 411            .add_request_handler(
 412                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 413            )
 414            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 415            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 416            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 417            .add_request_handler(multi_lsp_query)
 418            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 419            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 420            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 421            .add_message_handler(create_buffer_for_peer)
 422            .add_request_handler(update_buffer)
 423            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 424            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 425            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 426            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 427            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 428            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 429            .add_message_handler(
 430                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 431            )
 432            .add_request_handler(get_users)
 433            .add_request_handler(fuzzy_search_users)
 434            .add_request_handler(request_contact)
 435            .add_request_handler(remove_contact)
 436            .add_request_handler(respond_to_contact_request)
 437            .add_message_handler(subscribe_to_channels)
 438            .add_request_handler(create_channel)
 439            .add_request_handler(delete_channel)
 440            .add_request_handler(invite_channel_member)
 441            .add_request_handler(remove_channel_member)
 442            .add_request_handler(set_channel_member_role)
 443            .add_request_handler(set_channel_visibility)
 444            .add_request_handler(rename_channel)
 445            .add_request_handler(join_channel_buffer)
 446            .add_request_handler(leave_channel_buffer)
 447            .add_message_handler(update_channel_buffer)
 448            .add_request_handler(rejoin_channel_buffers)
 449            .add_request_handler(get_channel_members)
 450            .add_request_handler(respond_to_channel_invite)
 451            .add_request_handler(join_channel)
 452            .add_request_handler(join_channel_chat)
 453            .add_message_handler(leave_channel_chat)
 454            .add_request_handler(send_channel_message)
 455            .add_request_handler(remove_channel_message)
 456            .add_request_handler(update_channel_message)
 457            .add_request_handler(get_channel_messages)
 458            .add_request_handler(get_channel_messages_by_id)
 459            .add_request_handler(get_notifications)
 460            .add_request_handler(mark_notification_as_read)
 461            .add_request_handler(move_channel)
 462            .add_request_handler(reorder_channel)
 463            .add_request_handler(follow)
 464            .add_message_handler(unfollow)
 465            .add_message_handler(update_followers)
 466            .add_request_handler(get_private_user_info)
 467            .add_request_handler(get_llm_api_token)
 468            .add_request_handler(accept_terms_of_service)
 469            .add_message_handler(acknowledge_channel_message)
 470            .add_message_handler(acknowledge_buffer_version)
 471            .add_request_handler(get_supermaven_api_key)
 472            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 473            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 474            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 475            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 476            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 477            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 478            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 479            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 480            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 481            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 482            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 483            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 484            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 485            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 486            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 487            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 488            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 489            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 490            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 491            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 492            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 493            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 494            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 495            .add_message_handler(update_context);
 496
 497        Arc::new(server)
 498    }
 499
 500    pub async fn start(&self) -> Result<()> {
 501        let server_id = *self.id.lock();
 502        let app_state = self.app_state.clone();
 503        let peer = self.peer.clone();
 504        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 505        let pool = self.connection_pool.clone();
 506        let livekit_client = self.app_state.livekit_client.clone();
 507
 508        let span = info_span!("start server");
 509        self.app_state.executor.spawn_detached(
 510            async move {
 511                tracing::info!("waiting for cleanup timeout");
 512                timeout.await;
 513                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 514
 515                app_state
 516                    .db
 517                    .delete_stale_channel_chat_participants(
 518                        &app_state.config.zed_environment,
 519                        server_id,
 520                    )
 521                    .await
 522                    .trace_err();
 523
 524                if let Some((room_ids, channel_ids)) = app_state
 525                    .db
 526                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 527                    .await
 528                    .trace_err()
 529                {
 530                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 531                    tracing::info!(
 532                        stale_channel_buffer_count = channel_ids.len(),
 533                        "retrieved stale channel buffers"
 534                    );
 535
 536                    for channel_id in channel_ids {
 537                        if let Some(refreshed_channel_buffer) = app_state
 538                            .db
 539                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 540                            .await
 541                            .trace_err()
 542                        {
 543                            for connection_id in refreshed_channel_buffer.connection_ids {
 544                                peer.send(
 545                                    connection_id,
 546                                    proto::UpdateChannelBufferCollaborators {
 547                                        channel_id: channel_id.to_proto(),
 548                                        collaborators: refreshed_channel_buffer
 549                                            .collaborators
 550                                            .clone(),
 551                                    },
 552                                )
 553                                .trace_err();
 554                            }
 555                        }
 556                    }
 557
 558                    for room_id in room_ids {
 559                        let mut contacts_to_update = HashSet::default();
 560                        let mut canceled_calls_to_user_ids = Vec::new();
 561                        let mut livekit_room = String::new();
 562                        let mut delete_livekit_room = false;
 563
 564                        if let Some(mut refreshed_room) = app_state
 565                            .db
 566                            .clear_stale_room_participants(room_id, server_id)
 567                            .await
 568                            .trace_err()
 569                        {
 570                            tracing::info!(
 571                                room_id = room_id.0,
 572                                new_participant_count = refreshed_room.room.participants.len(),
 573                                "refreshed room"
 574                            );
 575                            room_updated(&refreshed_room.room, &peer);
 576                            if let Some(channel) = refreshed_room.channel.as_ref() {
 577                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 578                            }
 579                            contacts_to_update
 580                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 581                            contacts_to_update
 582                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 583                            canceled_calls_to_user_ids =
 584                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 585                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
 586                            delete_livekit_room = refreshed_room.room.participants.is_empty();
 587                        }
 588
 589                        {
 590                            let pool = pool.lock();
 591                            for canceled_user_id in canceled_calls_to_user_ids {
 592                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 593                                    peer.send(
 594                                        connection_id,
 595                                        proto::CallCanceled {
 596                                            room_id: room_id.to_proto(),
 597                                        },
 598                                    )
 599                                    .trace_err();
 600                                }
 601                            }
 602                        }
 603
 604                        for user_id in contacts_to_update {
 605                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 606                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 607                            if let Some((busy, contacts)) = busy.zip(contacts) {
 608                                let pool = pool.lock();
 609                                let updated_contact = contact_for_user(user_id, busy, &pool);
 610                                for contact in contacts {
 611                                    if let db::Contact::Accepted {
 612                                        user_id: contact_user_id,
 613                                        ..
 614                                    } = contact
 615                                    {
 616                                        for contact_conn_id in
 617                                            pool.user_connection_ids(contact_user_id)
 618                                        {
 619                                            peer.send(
 620                                                contact_conn_id,
 621                                                proto::UpdateContacts {
 622                                                    contacts: vec![updated_contact.clone()],
 623                                                    remove_contacts: Default::default(),
 624                                                    incoming_requests: Default::default(),
 625                                                    remove_incoming_requests: Default::default(),
 626                                                    outgoing_requests: Default::default(),
 627                                                    remove_outgoing_requests: Default::default(),
 628                                                },
 629                                            )
 630                                            .trace_err();
 631                                        }
 632                                    }
 633                                }
 634                            }
 635                        }
 636
 637                        if let Some(live_kit) = livekit_client.as_ref() {
 638                            if delete_livekit_room {
 639                                live_kit.delete_room(livekit_room).await.trace_err();
 640                            }
 641                        }
 642                    }
 643                }
 644
 645                app_state
 646                    .db
 647                    .delete_stale_channel_chat_participants(
 648                        &app_state.config.zed_environment,
 649                        server_id,
 650                    )
 651                    .await
 652                    .trace_err();
 653
 654                app_state
 655                    .db
 656                    .clear_old_worktree_entries(server_id)
 657                    .await
 658                    .trace_err();
 659
 660                app_state
 661                    .db
 662                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 663                    .await
 664                    .trace_err();
 665            }
 666            .instrument(span),
 667        );
 668        Ok(())
 669    }
 670
 671    pub fn teardown(&self) {
 672        self.peer.teardown();
 673        self.connection_pool.lock().reset();
 674        let _ = self.teardown.send(true);
 675    }
 676
 677    #[cfg(test)]
 678    pub fn reset(&self, id: ServerId) {
 679        self.teardown();
 680        *self.id.lock() = id;
 681        self.peer.reset(id.0 as u32);
 682        let _ = self.teardown.send(false);
 683    }
 684
 685    #[cfg(test)]
 686    pub fn id(&self) -> ServerId {
 687        *self.id.lock()
 688    }
 689
 690    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 691    where
 692        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
 693        Fut: 'static + Send + Future<Output = Result<()>>,
 694        M: EnvelopedMessage,
 695    {
 696        let prev_handler = self.handlers.insert(
 697            TypeId::of::<M>(),
 698            Box::new(move |envelope, session, span| {
 699                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 700                let received_at = envelope.received_at;
 701                tracing::info!("message received");
 702                let start_time = Instant::now();
 703                let future = (handler)(
 704                    *envelope,
 705                    MessageContext {
 706                        session,
 707                        span: span.clone(),
 708                    },
 709                );
 710                async move {
 711                    let result = future.await;
 712                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 713                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 714                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 715                    span.record(TOTAL_DURATION_MS, total_duration_ms);
 716                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
 717                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
 718                    match result {
 719                        Err(error) => {
 720                            tracing::error!(?error, "error handling message")
 721                        }
 722                        Ok(()) => tracing::info!("finished handling message"),
 723                    }
 724                }
 725                .boxed()
 726            }),
 727        );
 728        if prev_handler.is_some() {
 729            panic!("registered a handler for the same message twice");
 730        }
 731        self
 732    }
 733
 734    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 735    where
 736        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 737        Fut: 'static + Send + Future<Output = Result<()>>,
 738        M: EnvelopedMessage,
 739    {
 740        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 741        self
 742    }
 743
 744    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 745    where
 746        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 747        Fut: Send + Future<Output = Result<()>>,
 748        M: RequestMessage,
 749    {
 750        let handler = Arc::new(handler);
 751        self.add_handler(move |envelope, session| {
 752            let receipt = envelope.receipt();
 753            let handler = handler.clone();
 754            async move {
 755                let peer = session.peer.clone();
 756                let responded = Arc::new(AtomicBool::default());
 757                let response = Response {
 758                    peer: peer.clone(),
 759                    responded: responded.clone(),
 760                    receipt,
 761                };
 762                match (handler)(envelope.payload, response, session).await {
 763                    Ok(()) => {
 764                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 765                            Ok(())
 766                        } else {
 767                            Err(anyhow!("handler did not send a response"))?
 768                        }
 769                    }
 770                    Err(error) => {
 771                        let proto_err = match &error {
 772                            Error::Internal(err) => err.to_proto(),
 773                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 774                        };
 775                        peer.respond_with_error(receipt, proto_err)?;
 776                        Err(error)
 777                    }
 778                }
 779            }
 780        })
 781    }
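
    // Every request handler registered in `Server::new` has this shape: it receives the decoded
    // payload, a `Response` that must be used exactly once, and the `MessageContext`. A minimal
    // sketch (assuming `proto::Ping`'s response type is `proto::Ack`):
    //
    //     async fn ping_example(
    //         _request: proto::Ping,
    //         response: Response<proto::Ping>,
    //         _ctx: MessageContext,
    //     ) -> Result<()> {
    //         response.send(proto::Ack {})?;
    //         Ok(())
    //     }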
 782
 783    pub fn handle_connection(
 784        self: &Arc<Self>,
 785        connection: Connection,
 786        address: String,
 787        principal: Principal,
 788        zed_version: ZedVersion,
 789        release_channel: Option<String>,
 790        user_agent: Option<String>,
 791        geoip_country_code: Option<String>,
 792        system_id: Option<String>,
 793        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 794        executor: Executor,
 795        connection_guard: Option<ConnectionGuard>,
 796    ) -> impl Future<Output = ()> + use<> {
 797        let this = self.clone();
 798        let span = info_span!("handle connection", %address,
 799            connection_id=field::Empty,
 800            user_id=field::Empty,
 801            login=field::Empty,
 802            impersonator=field::Empty,
 803            user_agent=field::Empty,
 804            geoip_country_code=field::Empty,
 805            release_channel=field::Empty,
 806        );
 807        principal.update_span(&span);
 808        if let Some(user_agent) = user_agent {
 809            span.record("user_agent", user_agent);
 810        }
 811        if let Some(release_channel) = release_channel {
 812            span.record("release_channel", release_channel);
 813        }
 814
 815        if let Some(country_code) = geoip_country_code.as_ref() {
 816            span.record("geoip_country_code", country_code);
 817        }
 818
 819        let mut teardown = self.teardown.subscribe();
 820        async move {
 821            if *teardown.borrow() {
 822                tracing::error!("server is tearing down");
 823                return;
 824            }
 825
 826            let (connection_id, handle_io, mut incoming_rx) =
 827                this.peer.add_connection(connection, {
 828                    let executor = executor.clone();
 829                    move |duration| executor.sleep(duration)
 830                });
 831            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 832
 833            tracing::info!("connection opened");
 834
 835            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
 836            let http_client = match ReqwestClient::user_agent(&user_agent) {
 837                Ok(http_client) => Arc::new(http_client),
 838                Err(error) => {
 839                    tracing::error!(?error, "failed to create HTTP client");
 840                    return;
 841                }
 842            };
 843
 844            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
 845                |supermaven_admin_api_key| {
 846                    Arc::new(SupermavenAdminApi::new(
 847                        supermaven_admin_api_key.to_string(),
 848                        http_client.clone(),
 849                    ))
 850                },
 851            );
 852
 853            let session = Session {
 854                principal: principal.clone(),
 855                connection_id,
 856                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 857                peer: this.peer.clone(),
 858                connection_pool: this.connection_pool.clone(),
 859                app_state: this.app_state.clone(),
 860                geoip_country_code,
 861                system_id,
 862                _executor: executor.clone(),
 863                supermaven_client,
 864            };
 865
 866            if let Err(error) = this
 867                .send_initial_client_update(
 868                    connection_id,
 869                    zed_version,
 870                    send_connection_id,
 871                    &session,
 872                )
 873                .await
 874            {
 875                tracing::error!(?error, "failed to send initial client update");
 876                return;
 877            }
 878            drop(connection_guard);
 879
 880            let handle_io = handle_io.fuse();
 881            futures::pin_mut!(handle_io);
 882
  883            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
  884            // This prevents deadlocks when, e.g., client A performs a request to client B while
  885            // client B performs a request to client A. If both clients stopped processing further
  886            // messages until their respective request completed, neither would get a chance to
  887            // respond to the other client's request, and they would deadlock.
  888            //
  889            // This arrangement ensures we attempt to process earlier messages first, but fall
  890            // back to processing messages that arrived later, in the spirit of making progress.
 891            const MAX_CONCURRENT_HANDLERS: usize = 256;
 892            let mut foreground_message_handlers = FuturesUnordered::new();
 893            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
 894            let get_concurrent_handlers = {
 895                let concurrent_handlers = concurrent_handlers.clone();
 896                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
 897            };
 898            loop {
 899                let next_message = async {
 900                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 901                    let message = incoming_rx.next().await;
  902                    // Sample the concurrent handler count here, so that we know what the
  903                    // queue looked like as each handler started.
 904                    (permit, message, get_concurrent_handlers())
 905                }
 906                .fuse();
 907                futures::pin_mut!(next_message);
 908                futures::select_biased! {
 909                    _ = teardown.changed().fuse() => return,
 910                    result = handle_io => {
 911                        if let Err(error) = result {
 912                            tracing::error!(?error, "error handling I/O");
 913                        }
 914                        break;
 915                    }
 916                    _ = foreground_message_handlers.next() => {}
 917                    next_message = next_message => {
 918                        let (permit, message, concurrent_handlers) = next_message;
 919                        if let Some(message) = message {
 920                            let type_name = message.payload_type_name();
 921                            // note: we copy all the fields from the parent span so we can query them in the logs.
 922                            // (https://github.com/tokio-rs/tracing/issues/2670).
 923                            let span = tracing::info_span!("receive message",
 924                                %connection_id,
 925                                %address,
 926                                type_name,
 927                                concurrent_handlers,
 928                                user_id=field::Empty,
 929                                login=field::Empty,
 930                                impersonator=field::Empty,
 931                                multi_lsp_query_request=field::Empty,
 932                                { TOTAL_DURATION_MS }=field::Empty,
 933                                { PROCESSING_DURATION_MS }=field::Empty,
 934                                { QUEUE_DURATION_MS }=field::Empty,
 935                                { HOST_WAITING_MS }=field::Empty
 936                            );
 937                            principal.update_span(&span);
 938                            let span_enter = span.enter();
 939                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 940                                let is_background = message.is_background();
 941                                let handle_message = (handler)(message, session.clone(), span.clone());
 942                                drop(span_enter);
 943
 944                                let handle_message = async move {
 945                                    handle_message.await;
 946                                    drop(permit);
 947                                }.instrument(span);
 948                                if is_background {
 949                                    executor.spawn_detached(handle_message);
 950                                } else {
 951                                    foreground_message_handlers.push(handle_message);
 952                                }
 953                            } else {
 954                                tracing::error!("no message handler");
 955                            }
 956                        } else {
 957                            tracing::info!("connection closed");
 958                            break;
 959                        }
 960                    }
 961                }
 962            }
 963
 964            drop(foreground_message_handlers);
 965            let concurrent_handlers = get_concurrent_handlers();
 966            tracing::info!(concurrent_handlers, "signing out");
 967            if let Err(error) = connection_lost(session, teardown, executor).await {
 968                tracing::error!(?error, "error signing out");
 969            }
 970        }
 971        .instrument(span)
 972    }
 973
 974    async fn send_initial_client_update(
 975        &self,
 976        connection_id: ConnectionId,
 977        zed_version: ZedVersion,
 978        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 979        session: &Session,
 980    ) -> Result<()> {
 981        self.peer.send(
 982            connection_id,
 983            proto::Hello {
 984                peer_id: Some(connection_id.into()),
 985            },
 986        )?;
 987        tracing::info!("sent hello message");
 988        if let Some(send_connection_id) = send_connection_id.take() {
 989            let _ = send_connection_id.send(connection_id);
 990        }
 991
 992        match &session.principal {
 993            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 994                if !user.connected_once {
 995                    self.peer.send(connection_id, proto::ShowContacts {})?;
 996                    self.app_state
 997                        .db
 998                        .set_user_connected_once(user.id, true)
 999                        .await?;
1000                }
1001
1002                update_user_plan(session).await?;
1003
1004                let contacts = self.app_state.db.get_contacts(user.id).await?;
1005
1006                {
1007                    let mut pool = self.connection_pool.lock();
1008                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1009                    self.peer.send(
1010                        connection_id,
1011                        build_initial_contacts_update(contacts, &pool),
1012                    )?;
1013                }
1014
1015                if should_auto_subscribe_to_channels(zed_version) {
1016                    subscribe_user_to_channels(user.id, session).await?;
1017                }
1018
1019                if let Some(incoming_call) =
1020                    self.app_state.db.incoming_call_for_user(user.id).await?
1021                {
1022                    self.peer.send(connection_id, incoming_call)?;
1023                }
1024
1025                update_user_contacts(user.id, session).await?;
1026            }
1027        }
1028
1029        Ok(())
1030    }
1031
1032    pub async fn invite_code_redeemed(
1033        self: &Arc<Self>,
1034        inviter_id: UserId,
1035        invitee_id: UserId,
1036    ) -> Result<()> {
1037        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1038            if let Some(code) = &user.invite_code {
1039                let pool = self.connection_pool.lock();
1040                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1041                for connection_id in pool.user_connection_ids(inviter_id) {
1042                    self.peer.send(
1043                        connection_id,
1044                        proto::UpdateContacts {
1045                            contacts: vec![invitee_contact.clone()],
1046                            ..Default::default()
1047                        },
1048                    )?;
1049                    self.peer.send(
1050                        connection_id,
1051                        proto::UpdateInviteInfo {
1052                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1053                            count: user.invite_count as u32,
1054                        },
1055                    )?;
1056                }
1057            }
1058        }
1059        Ok(())
1060    }
1061
1062    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1063        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1064            if let Some(invite_code) = &user.invite_code {
1065                let pool = self.connection_pool.lock();
1066                for connection_id in pool.user_connection_ids(user_id) {
1067                    self.peer.send(
1068                        connection_id,
1069                        proto::UpdateInviteInfo {
1070                            url: format!(
1071                                "{}{}",
1072                                self.app_state.config.invite_link_prefix, invite_code
1073                            ),
1074                            count: user.invite_count as u32,
1075                        },
1076                    )?;
1077                }
1078            }
1079        }
1080        Ok(())
1081    }
1082
1083    pub async fn update_plan_for_user(
1084        self: &Arc<Self>,
1085        user_id: UserId,
1086        update_user_plan: proto::UpdateUserPlan,
1087    ) -> Result<()> {
1088        let pool = self.connection_pool.lock();
1089        for connection_id in pool.user_connection_ids(user_id) {
1090            self.peer
1091                .send(connection_id, update_user_plan.clone())
1092                .trace_err();
1093        }
1094
1095        Ok(())
1096    }
1097
1098    /// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan`
1099    /// message on the Collab server.
1100    ///
1101    /// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint.
1102    pub async fn update_plan_for_user_legacy(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1103        let user = self
1104            .app_state
1105            .db
1106            .get_user_by_id(user_id)
1107            .await?
1108            .context("user not found")?;
1109
1110        let update_user_plan = make_update_user_plan_message(
1111            &user,
1112            user.admin,
1113            &self.app_state.db,
1114            self.app_state.llm_db.clone(),
1115        )
1116        .await?;
1117
1118        self.update_plan_for_user(user_id, update_user_plan).await
1119    }
1120
1121    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
1122        let pool = self.connection_pool.lock();
1123        for connection_id in pool.user_connection_ids(user_id) {
1124            self.peer
1125                .send(connection_id, proto::RefreshLlmToken {})
1126                .trace_err();
1127        }
1128    }
1129
1130    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1131        ServerSnapshot {
1132            connection_pool: ConnectionPoolGuard {
1133                guard: self.connection_pool.lock(),
1134                _not_send: PhantomData,
1135            },
1136            peer: &self.peer,
1137        }
1138    }
1139}
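
// The snapshot borrows the live connection pool, so it is serialized while the guard is held.
// A sketch of a debug dump, assuming `serde_json` is available:
//
//     async fn dump_server_state(server: &Arc<Server>) -> anyhow::Result<String> {
//         let snapshot = server.snapshot().await;
//         Ok(serde_json::to_string(&snapshot)?)
//     }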
1140
1141impl Deref for ConnectionPoolGuard<'_> {
1142    type Target = ConnectionPool;
1143
1144    fn deref(&self) -> &Self::Target {
1145        &self.guard
1146    }
1147}
1148
1149impl DerefMut for ConnectionPoolGuard<'_> {
1150    fn deref_mut(&mut self) -> &mut Self::Target {
1151        &mut self.guard
1152    }
1153}
1154
1155impl Drop for ConnectionPoolGuard<'_> {
1156    fn drop(&mut self) {
1157        #[cfg(test)]
1158        self.check_invariants();
1159    }
1160}
1161
1162fn broadcast<F>(
1163    sender_id: Option<ConnectionId>,
1164    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1165    mut f: F,
1166) where
1167    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1168{
1169    for receiver_id in receiver_ids {
1170        if Some(receiver_id) != sender_id {
1171            if let Err(error) = f(receiver_id) {
1172                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1173            }
1174        }
1175    }
1176}
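
// `broadcast` skips the sender and logs per-receiver send failures instead of propagating them.
// A sketch of typical use, assuming `collaborator_ids` and a cloneable `message` are in scope:
//
//     broadcast(Some(ctx.connection_id), collaborator_ids, |connection_id| {
//         ctx.peer.send(connection_id, message.clone())
//     });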
1177
1178pub struct ProtocolVersion(u32);
1179
1180impl Header for ProtocolVersion {
1181    fn name() -> &'static HeaderName {
1182        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1183        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1184    }
1185
1186    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1187    where
1188        Self: Sized,
1189        I: Iterator<Item = &'i axum::http::HeaderValue>,
1190    {
1191        let version = values
1192            .next()
1193            .ok_or_else(axum::headers::Error::invalid)?
1194            .to_str()
1195            .map_err(|_| axum::headers::Error::invalid())?
1196            .parse()
1197            .map_err(|_| axum::headers::Error::invalid())?;
1198        Ok(Self(version))
1199    }
1200
1201    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1202        values.extend([self.0.to_string().parse().unwrap()]);
1203    }
1204}
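
// These typed headers (and the two that follow) round-trip through plain string header values.
// A sketch of encoding the protocol version, using the `Header` trait already in scope:
//
//     let mut values: Vec<axum::http::HeaderValue> = Vec::new();
//     ProtocolVersion(rpc::PROTOCOL_VERSION).encode(&mut values);
//     // `values` now holds a single "x-zed-protocol-version" value.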
1205
1206pub struct AppVersionHeader(SemanticVersion);
1207impl Header for AppVersionHeader {
1208    fn name() -> &'static HeaderName {
1209        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1210        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1211    }
1212
1213    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1214    where
1215        Self: Sized,
1216        I: Iterator<Item = &'i axum::http::HeaderValue>,
1217    {
1218        let version = values
1219            .next()
1220            .ok_or_else(axum::headers::Error::invalid)?
1221            .to_str()
1222            .map_err(|_| axum::headers::Error::invalid())?
1223            .parse()
1224            .map_err(|_| axum::headers::Error::invalid())?;
1225        Ok(Self(version))
1226    }
1227
1228    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1229        values.extend([self.0.to_string().parse().unwrap()]);
1230    }
1231}
1232
1233#[derive(Debug)]
1234pub struct ReleaseChannelHeader(String);
1235
1236impl Header for ReleaseChannelHeader {
1237    fn name() -> &'static HeaderName {
1238        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1239        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1240    }
1241
1242    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1243    where
1244        Self: Sized,
1245        I: Iterator<Item = &'i axum::http::HeaderValue>,
1246    {
1247        Ok(Self(
1248            values
1249                .next()
1250                .ok_or_else(axum::headers::Error::invalid)?
1251                .to_str()
1252                .map_err(|_| axum::headers::Error::invalid())?
1253                .to_owned(),
1254        ))
1255    }
1256
1257    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1258        values.extend([self.0.parse().unwrap()]);
1259    }
1260}
1261
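/// Builds the collab RPC router: `/rpc` upgrades to the collaboration WebSocket
/// (with the shared `AppState` attached and `auth::validate_header` run first),
/// and `/metrics` serves Prometheus metrics. Both routes receive the `Server`
/// itself as an extension.
///
/// A minimal sketch of serving it, assuming an axum 0.6-style setup (the bind
/// address is illustrative and error handling is elided):
///
/// ```ignore
/// let app = routes(server.clone());
/// axum::Server::bind(&"0.0.0.0:8080".parse().unwrap())
///     .serve(app.into_make_service_with_connect_info::<std::net::SocketAddr>())
///     .await
///     .unwrap();
/// ```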
1262pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1263    Router::new()
1264        .route("/rpc", get(handle_websocket_request))
1265        .layer(
1266            ServiceBuilder::new()
1267                .layer(Extension(server.app_state.clone()))
1268                .layer(middleware::from_fn(auth::validate_header)),
1269        )
1270        .route("/metrics", get(handle_metrics))
1271        .layer(Extension(server))
1272}
1273
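/// Handles `GET /rpc`. Clients with a mismatched protocol version, a missing app
/// version header, or an app version that can no longer collaborate are rejected
/// with `426 Upgrade Required`; if the concurrent-connection guard cannot be
/// acquired the request fails with `503 Service Unavailable`. Otherwise the
/// request is upgraded to a WebSocket and handed to `handle_connection`.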
1274pub async fn handle_websocket_request(
1275    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1276    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1277    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1278    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1279    Extension(server): Extension<Arc<Server>>,
1280    Extension(principal): Extension<Principal>,
1281    user_agent: Option<TypedHeader<UserAgent>>,
1282    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1283    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1284    ws: WebSocketUpgrade,
1285) -> axum::response::Response {
1286    if protocol_version != rpc::PROTOCOL_VERSION {
1287        return (
1288            StatusCode::UPGRADE_REQUIRED,
1289            "client must be upgraded".to_string(),
1290        )
1291            .into_response();
1292    }
1293
1294    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1295        return (
1296            StatusCode::UPGRADE_REQUIRED,
1297            "no version header found".to_string(),
1298        )
1299            .into_response();
1300    };
1301
1302    let release_channel = release_channel_header.map(|header| header.0.0);
1303
1304    if !version.can_collaborate() {
1305        return (
1306            StatusCode::UPGRADE_REQUIRED,
1307            "client must be upgraded".to_string(),
1308        )
1309            .into_response();
1310    }
1311
1312    let socket_address = socket_address.to_string();
1313
1314    // Acquire connection guard before WebSocket upgrade
1315    let connection_guard = match ConnectionGuard::try_acquire() {
1316        Ok(guard) => guard,
1317        Err(()) => {
1318            return (
1319                StatusCode::SERVICE_UNAVAILABLE,
1320                "Too many concurrent connections",
1321            )
1322                .into_response();
1323        }
1324    };
1325
1326    ws.on_upgrade(move |socket| {
1327        let socket = socket
1328            .map_ok(to_tungstenite_message)
1329            .err_into()
1330            .with(|message| async move { to_axum_message(message) });
1331        let connection = Connection::new(Box::pin(socket));
1332        async move {
1333            server
1334                .handle_connection(
1335                    connection,
1336                    socket_address,
1337                    principal,
1338                    version,
1339                    release_channel,
1340                    user_agent.map(|header| header.to_string()),
1341                    country_code_header.map(|header| header.to_string()),
1342                    system_id_header.map(|header| header.to_string()),
1343                    None,
1344                    Executor::Production,
1345                    Some(connection_guard),
1346                )
1347                .await;
1348        }
1349    })
1350}
1351
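/// Serves `GET /metrics` in the Prometheus text exposition format. On every
/// scrape the `connections` gauge is refreshed from the connection pool
/// (excluding admins) and the `shared_projects` gauge from the database, then
/// everything registered with the default registry is encoded.
///
/// The response looks roughly like this (values illustrative):
///
/// ```text
/// # HELP connections number of connections
/// # TYPE connections gauge
/// connections 42
/// # HELP shared_projects number of open projects with one or more guests
/// # TYPE shared_projects gauge
/// shared_projects 7
/// ```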
1352pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1353    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1354    let connections_metric = CONNECTIONS_METRIC
1355        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1356
1357    let connections = server
1358        .connection_pool
1359        .lock()
1360        .connections()
1361        .filter(|connection| !connection.admin)
1362        .count();
1363    connections_metric.set(connections as _);
1364
1365    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1366    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1367        register_int_gauge!(
1368            "shared_projects",
1369            "number of open projects with one or more guests"
1370        )
1371        .unwrap()
1372    });
1373
1374    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1375    shared_projects_metric.set(shared_projects as _);
1376
1377    let encoder = prometheus::TextEncoder::new();
1378    let metric_families = prometheus::gather();
1379    let encoded_metrics = encoder
1380        .encode_to_string(&metric_families)
1381        .map_err(|err| anyhow!("{err}"))?;
1382    Ok(encoded_metrics)
1383}
1384
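/// Runs when a client's connection drops: the connection is removed from the
/// peer and the pool right away, but room and channel-buffer state is only torn
/// down if the client has not reconnected within `RECONNECT_TIMEOUT`. A signal
/// on `teardown` cancels that delayed cleanup.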
1385#[instrument(err, skip(executor))]
1386async fn connection_lost(
1387    session: Session,
1388    mut teardown: watch::Receiver<bool>,
1389    executor: Executor,
1390) -> Result<()> {
1391    session.peer.disconnect(session.connection_id);
1392    session
1393        .connection_pool()
1394        .await
1395        .remove_connection(session.connection_id)?;
1396
1397    session
1398        .db()
1399        .await
1400        .connection_lost(session.connection_id)
1401        .await
1402        .trace_err();
1403
1404    futures::select_biased! {
1405        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1407            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1408            leave_room_for_session(&session, session.connection_id).await.trace_err();
1409            leave_channel_buffers_for_session(&session)
1410                .await
1411                .trace_err();
1412
1413            if !session
1414                .connection_pool()
1415                .await
1416                .is_user_online(session.user_id())
1417            {
1418                let db = session.db().await;
1419                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1420                    room_updated(&room, &session.peer);
1421                }
1422            }
1423
1424            update_user_contacts(session.user_id(), &session).await?;
1425        },
1426        _ = teardown.changed().fuse() => {}
1427    }
1428
1429    Ok(())
1430}
1431
1432/// Acknowledges a ping from a client, used to keep the connection alive.
1433async fn ping(
1434    _: proto::Ping,
1435    response: Response<proto::Ping>,
1436    _session: MessageContext,
1437) -> Result<()> {
1438    response.send(proto::Ack {})?;
1439    Ok(())
1440}
1441
1442/// Creates a new room for calling (outside of channels).
1443async fn create_room(
1444    _request: proto::CreateRoom,
1445    response: Response<proto::CreateRoom>,
1446    session: MessageContext,
1447) -> Result<()> {
1448    let livekit_room = nanoid::nanoid!(30);
1449
1450    let live_kit_connection_info = util::maybe!(async {
1451        let live_kit = session.app_state.livekit_client.as_ref();
1452        let live_kit = live_kit?;
1453        let user_id = session.user_id().to_string();
1454
1455        let token = live_kit
1456            .room_token(&livekit_room, &user_id)
1457            .trace_err()?;
1458
1459        Some(proto::LiveKitConnectionInfo {
1460            server_url: live_kit.url().into(),
1461            token,
1462            can_publish: true,
1463        })
1464    })
1465    .await;
1466
1467    let room = session
1468        .db()
1469        .await
1470        .create_room(session.user_id(), session.connection_id, &livekit_room)
1471        .await?;
1472
1473    response.send(proto::CreateRoomResponse {
1474        room: Some(room.clone()),
1475        live_kit_connection_info,
1476    })?;
1477
1478    update_user_contacts(session.user_id(), &session).await?;
1479    Ok(())
1480}
1481
1482/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1483async fn join_room(
1484    request: proto::JoinRoom,
1485    response: Response<proto::JoinRoom>,
1486    session: MessageContext,
1487) -> Result<()> {
1488    let room_id = RoomId::from_proto(request.id);
1489
1490    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1491
1492    if let Some(channel_id) = channel_id {
1493        return join_channel_internal(channel_id, Box::new(response), session).await;
1494    }
1495
1496    let joined_room = {
1497        let room = session
1498            .db()
1499            .await
1500            .join_room(room_id, session.user_id(), session.connection_id)
1501            .await?;
1502        room_updated(&room.room, &session.peer);
1503        room.into_inner()
1504    };
1505
1506    for connection_id in session
1507        .connection_pool()
1508        .await
1509        .user_connection_ids(session.user_id())
1510    {
1511        session
1512            .peer
1513            .send(
1514                connection_id,
1515                proto::CallCanceled {
1516                    room_id: room_id.to_proto(),
1517                },
1518            )
1519            .trace_err();
1520    }
1521
1522    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1523    {
1524        live_kit
1525            .room_token(
1526                &joined_room.room.livekit_room,
1527                &session.user_id().to_string(),
1528            )
1529            .trace_err()
1530            .map(|token| proto::LiveKitConnectionInfo {
1531                server_url: live_kit.url().into(),
1532                token,
1533                can_publish: true,
1534            })
1535    } else {
1536        None
1537    };
1538
1539    response.send(proto::JoinRoomResponse {
1540        room: Some(joined_room.room),
1541        channel_id: None,
1542        live_kit_connection_info,
1543    })?;
1544
1545    update_user_contacts(session.user_id(), &session).await?;
1546    Ok(())
1547}
1548
1549/// Rejoins a room; used to reconnect to a room after connection errors.
1550async fn rejoin_room(
1551    request: proto::RejoinRoom,
1552    response: Response<proto::RejoinRoom>,
1553    session: MessageContext,
1554) -> Result<()> {
1555    let room;
1556    let channel;
1557    {
1558        let mut rejoined_room = session
1559            .db()
1560            .await
1561            .rejoin_room(request, session.user_id(), session.connection_id)
1562            .await?;
1563
1564        response.send(proto::RejoinRoomResponse {
1565            room: Some(rejoined_room.room.clone()),
1566            reshared_projects: rejoined_room
1567                .reshared_projects
1568                .iter()
1569                .map(|project| proto::ResharedProject {
1570                    id: project.id.to_proto(),
1571                    collaborators: project
1572                        .collaborators
1573                        .iter()
1574                        .map(|collaborator| collaborator.to_proto())
1575                        .collect(),
1576                })
1577                .collect(),
1578            rejoined_projects: rejoined_room
1579                .rejoined_projects
1580                .iter()
1581                .map(|rejoined_project| rejoined_project.to_proto())
1582                .collect(),
1583        })?;
1584        room_updated(&rejoined_room.room, &session.peer);
1585
1586        for project in &rejoined_room.reshared_projects {
1587            for collaborator in &project.collaborators {
1588                session
1589                    .peer
1590                    .send(
1591                        collaborator.connection_id,
1592                        proto::UpdateProjectCollaborator {
1593                            project_id: project.id.to_proto(),
1594                            old_peer_id: Some(project.old_connection_id.into()),
1595                            new_peer_id: Some(session.connection_id.into()),
1596                        },
1597                    )
1598                    .trace_err();
1599            }
1600
1601            broadcast(
1602                Some(session.connection_id),
1603                project
1604                    .collaborators
1605                    .iter()
1606                    .map(|collaborator| collaborator.connection_id),
1607                |connection_id| {
1608                    session.peer.forward_send(
1609                        session.connection_id,
1610                        connection_id,
1611                        proto::UpdateProject {
1612                            project_id: project.id.to_proto(),
1613                            worktrees: project.worktrees.clone(),
1614                        },
1615                    )
1616                },
1617            );
1618        }
1619
1620        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1621
1622        let rejoined_room = rejoined_room.into_inner();
1623
1624        room = rejoined_room.room;
1625        channel = rejoined_room.channel;
1626    }
1627
1628    if let Some(channel) = channel {
1629        channel_updated(
1630            &channel,
1631            &room,
1632            &session.peer,
1633            &*session.connection_pool().await,
1634        );
1635    }
1636
1637    update_user_contacts(session.user_id(), &session).await?;
1638    Ok(())
1639}
1640
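/// Replays a rejoined project's state: other collaborators are told about the
/// rejoining peer's new connection id, then worktree entries, diagnostics,
/// settings files, and repository updates are streamed back to the rejoining
/// connection in chunked messages.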
1641fn notify_rejoined_projects(
1642    rejoined_projects: &mut Vec<RejoinedProject>,
1643    session: &Session,
1644) -> Result<()> {
1645    for project in rejoined_projects.iter() {
1646        for collaborator in &project.collaborators {
1647            session
1648                .peer
1649                .send(
1650                    collaborator.connection_id,
1651                    proto::UpdateProjectCollaborator {
1652                        project_id: project.id.to_proto(),
1653                        old_peer_id: Some(project.old_connection_id.into()),
1654                        new_peer_id: Some(session.connection_id.into()),
1655                    },
1656                )
1657                .trace_err();
1658        }
1659    }
1660
1661    for project in rejoined_projects {
1662        for worktree in mem::take(&mut project.worktrees) {
1663            // Stream this worktree's entries.
1664            let message = proto::UpdateWorktree {
1665                project_id: project.id.to_proto(),
1666                worktree_id: worktree.id,
1667                abs_path: worktree.abs_path.clone(),
1668                root_name: worktree.root_name,
1669                updated_entries: worktree.updated_entries,
1670                removed_entries: worktree.removed_entries,
1671                scan_id: worktree.scan_id,
1672                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1673                updated_repositories: worktree.updated_repositories,
1674                removed_repositories: worktree.removed_repositories,
1675            };
1676            for update in proto::split_worktree_update(message) {
1677                session.peer.send(session.connection_id, update)?;
1678            }
1679
1680            // Stream this worktree's diagnostics.
1681            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1682            if let Some(summary) = worktree_diagnostics.next() {
1683                let message = proto::UpdateDiagnosticSummary {
1684                    project_id: project.id.to_proto(),
1685                    worktree_id: worktree.id,
1686                    summary: Some(summary),
1687                    more_summaries: worktree_diagnostics.collect(),
1688                };
1689                session.peer.send(session.connection_id, message)?;
1690            }
1691
1692            for settings_file in worktree.settings_files {
1693                session.peer.send(
1694                    session.connection_id,
1695                    proto::UpdateWorktreeSettings {
1696                        project_id: project.id.to_proto(),
1697                        worktree_id: worktree.id,
1698                        path: settings_file.path,
1699                        content: Some(settings_file.content),
1700                        kind: Some(settings_file.kind.to_proto().into()),
1701                    },
1702                )?;
1703            }
1704        }
1705
1706        for repository in mem::take(&mut project.updated_repositories) {
1707            for update in split_repository_update(repository) {
1708                session.peer.send(session.connection_id, update)?;
1709            }
1710        }
1711
1712        for id in mem::take(&mut project.removed_repositories) {
1713            session.peer.send(
1714                session.connection_id,
1715                proto::RemoveRepository {
1716                    project_id: project.id.to_proto(),
1717                    id,
1718                },
1719            )?;
1720        }
1721    }
1722
1723    Ok(())
1724}
1725
1726/// Leaves the room, disconnecting the caller from it.
1727async fn leave_room(
1728    _: proto::LeaveRoom,
1729    response: Response<proto::LeaveRoom>,
1730    session: MessageContext,
1731) -> Result<()> {
1732    leave_room_for_session(&session, session.connection_id).await?;
1733    response.send(proto::Ack {})?;
1734    Ok(())
1735}
1736
1737/// Updates the permissions of someone else in the room.
1738async fn set_room_participant_role(
1739    request: proto::SetRoomParticipantRole,
1740    response: Response<proto::SetRoomParticipantRole>,
1741    session: MessageContext,
1742) -> Result<()> {
1743    let user_id = UserId::from_proto(request.user_id);
1744    let role = ChannelRole::from(request.role());
1745
1746    let (livekit_room, can_publish) = {
1747        let room = session
1748            .db()
1749            .await
1750            .set_room_participant_role(
1751                session.user_id(),
1752                RoomId::from_proto(request.room_id),
1753                user_id,
1754                role,
1755            )
1756            .await?;
1757
1758        let livekit_room = room.livekit_room.clone();
1759        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1760        room_updated(&room, &session.peer);
1761        (livekit_room, can_publish)
1762    };
1763
1764    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1765        live_kit
1766            .update_participant(
1767                livekit_room.clone(),
1768                request.user_id.to_string(),
1769                livekit_api::proto::ParticipantPermission {
1770                    can_subscribe: true,
1771                    can_publish,
1772                    can_publish_data: can_publish,
1773                    hidden: false,
1774                    recorder: false,
1775                },
1776            )
1777            .await
1778            .trace_err();
1779    }
1780
1781    response.send(proto::Ack {})?;
1782    Ok(())
1783}
1784
1785/// Call someone else into the current room.
1786async fn call(
1787    request: proto::Call,
1788    response: Response<proto::Call>,
1789    session: MessageContext,
1790) -> Result<()> {
1791    let room_id = RoomId::from_proto(request.room_id);
1792    let calling_user_id = session.user_id();
1793    let calling_connection_id = session.connection_id;
1794    let called_user_id = UserId::from_proto(request.called_user_id);
1795    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1796    if !session
1797        .db()
1798        .await
1799        .has_contact(calling_user_id, called_user_id)
1800        .await?
1801    {
1802        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1803    }
1804
1805    let incoming_call = {
1806        let (room, incoming_call) = &mut *session
1807            .db()
1808            .await
1809            .call(
1810                room_id,
1811                calling_user_id,
1812                calling_connection_id,
1813                called_user_id,
1814                initial_project_id,
1815            )
1816            .await?;
1817        room_updated(room, &session.peer);
1818        mem::take(incoming_call)
1819    };
1820    update_user_contacts(called_user_id, &session).await?;
1821
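    // Ring every connection the callee currently has open; the first one to
    // answer acknowledges the call and the remaining requests are dropped.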
1822    let mut calls = session
1823        .connection_pool()
1824        .await
1825        .user_connection_ids(called_user_id)
1826        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1827        .collect::<FuturesUnordered<_>>();
1828
1829    while let Some(call_response) = calls.next().await {
1830        match call_response.as_ref() {
1831            Ok(_) => {
1832                response.send(proto::Ack {})?;
1833                return Ok(());
1834            }
1835            Err(_) => {
1836                call_response.trace_err();
1837            }
1838        }
1839    }
1840
1841    {
1842        let room = session
1843            .db()
1844            .await
1845            .call_failed(room_id, called_user_id)
1846            .await?;
1847        room_updated(&room, &session.peer);
1848    }
1849    update_user_contacts(called_user_id, &session).await?;
1850
1851    Err(anyhow!("failed to ring user"))?
1852}
1853
1854/// Cancel an outgoing call.
1855async fn cancel_call(
1856    request: proto::CancelCall,
1857    response: Response<proto::CancelCall>,
1858    session: MessageContext,
1859) -> Result<()> {
1860    let called_user_id = UserId::from_proto(request.called_user_id);
1861    let room_id = RoomId::from_proto(request.room_id);
1862    {
1863        let room = session
1864            .db()
1865            .await
1866            .cancel_call(room_id, session.connection_id, called_user_id)
1867            .await?;
1868        room_updated(&room, &session.peer);
1869    }
1870
1871    for connection_id in session
1872        .connection_pool()
1873        .await
1874        .user_connection_ids(called_user_id)
1875    {
1876        session
1877            .peer
1878            .send(
1879                connection_id,
1880                proto::CallCanceled {
1881                    room_id: room_id.to_proto(),
1882                },
1883            )
1884            .trace_err();
1885    }
1886    response.send(proto::Ack {})?;
1887
1888    update_user_contacts(called_user_id, &session).await?;
1889    Ok(())
1890}
1891
1892/// Decline an incoming call.
1893async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1894    let room_id = RoomId::from_proto(message.room_id);
1895    {
1896        let room = session
1897            .db()
1898            .await
1899            .decline_call(Some(room_id), session.user_id())
1900            .await?
1901            .context("declining call")?;
1902        room_updated(&room, &session.peer);
1903    }
1904
1905    for connection_id in session
1906        .connection_pool()
1907        .await
1908        .user_connection_ids(session.user_id())
1909    {
1910        session
1911            .peer
1912            .send(
1913                connection_id,
1914                proto::CallCanceled {
1915                    room_id: room_id.to_proto(),
1916                },
1917            )
1918            .trace_err();
1919    }
1920    update_user_contacts(session.user_id(), &session).await?;
1921    Ok(())
1922}
1923
1924/// Updates other participants in the room with your current location.
1925async fn update_participant_location(
1926    request: proto::UpdateParticipantLocation,
1927    response: Response<proto::UpdateParticipantLocation>,
1928    session: MessageContext,
1929) -> Result<()> {
1930    let room_id = RoomId::from_proto(request.room_id);
1931    let location = request.location.context("invalid location")?;
1932
1933    let db = session.db().await;
1934    let room = db
1935        .update_room_participant_location(room_id, session.connection_id, location)
1936        .await?;
1937
1938    room_updated(&room, &session.peer);
1939    response.send(proto::Ack {})?;
1940    Ok(())
1941}
1942
1943/// Share a project into the room.
1944async fn share_project(
1945    request: proto::ShareProject,
1946    response: Response<proto::ShareProject>,
1947    session: MessageContext,
1948) -> Result<()> {
1949    let (project_id, room) = &*session
1950        .db()
1951        .await
1952        .share_project(
1953            RoomId::from_proto(request.room_id),
1954            session.connection_id,
1955            &request.worktrees,
1956            request.is_ssh_project,
1957        )
1958        .await?;
1959    response.send(proto::ShareProjectResponse {
1960        project_id: project_id.to_proto(),
1961    })?;
1962    room_updated(room, &session.peer);
1963
1964    Ok(())
1965}
1966
1967/// Unshare a project from the room.
1968async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1969    let project_id = ProjectId::from_proto(message.project_id);
1970    unshare_project_internal(project_id, session.connection_id, &session).await
1971}
1972
1973async fn unshare_project_internal(
1974    project_id: ProjectId,
1975    connection_id: ConnectionId,
1976    session: &Session,
1977) -> Result<()> {
1978    let delete = {
1979        let room_guard = session
1980            .db()
1981            .await
1982            .unshare_project(project_id, connection_id)
1983            .await?;
1984
1985        let (delete, room, guest_connection_ids) = &*room_guard;
1986
1987        let message = proto::UnshareProject {
1988            project_id: project_id.to_proto(),
1989        };
1990
1991        broadcast(
1992            Some(connection_id),
1993            guest_connection_ids.iter().copied(),
1994            |conn_id| session.peer.send(conn_id, message.clone()),
1995        );
1996        if let Some(room) = room {
1997            room_updated(room, &session.peer);
1998        }
1999
2000        *delete
2001    };
2002
2003    if delete {
2004        let db = session.db().await;
2005        db.delete_project(project_id).await?;
2006    }
2007
2008    Ok(())
2009}
2010
2011/// Join someone else's shared project.
2012async fn join_project(
2013    request: proto::JoinProject,
2014    response: Response<proto::JoinProject>,
2015    session: MessageContext,
2016) -> Result<()> {
2017    let project_id = ProjectId::from_proto(request.project_id);
2018
2019    tracing::info!(%project_id, "join project");
2020
2021    let db = session.db().await;
2022    let (project, replica_id) = &mut *db
2023        .join_project(
2024            project_id,
2025            session.connection_id,
2026            session.user_id(),
2027            request.committer_name.clone(),
2028            request.committer_email.clone(),
2029        )
2030        .await?;
2031    drop(db);
2032    tracing::info!(%project_id, "join remote project");
2033    let collaborators = project
2034        .collaborators
2035        .iter()
2036        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2037        .map(|collaborator| collaborator.to_proto())
2038        .collect::<Vec<_>>();
2039    let project_id = project.id;
2040    let guest_user_id = session.user_id();
2041
2042    let worktrees = project
2043        .worktrees
2044        .iter()
2045        .map(|(id, worktree)| proto::WorktreeMetadata {
2046            id: *id,
2047            root_name: worktree.root_name.clone(),
2048            visible: worktree.visible,
2049            abs_path: worktree.abs_path.clone(),
2050        })
2051        .collect::<Vec<_>>();
2052
2053    let add_project_collaborator = proto::AddProjectCollaborator {
2054        project_id: project_id.to_proto(),
2055        collaborator: Some(proto::Collaborator {
2056            peer_id: Some(session.connection_id.into()),
2057            replica_id: replica_id.0 as u32,
2058            user_id: guest_user_id.to_proto(),
2059            is_host: false,
2060            committer_name: request.committer_name.clone(),
2061            committer_email: request.committer_email.clone(),
2062        }),
2063    };
2064
2065    for collaborator in &collaborators {
2066        session
2067            .peer
2068            .send(
2069                collaborator.peer_id.unwrap().into(),
2070                add_project_collaborator.clone(),
2071            )
2072            .trace_err();
2073    }
2074
2075    // First, we send the metadata associated with each worktree.
2076    let (language_servers, language_server_capabilities) = project
2077        .language_servers
2078        .clone()
2079        .into_iter()
2080        .map(|server| (server.server, server.capabilities))
2081        .unzip();
2082    response.send(proto::JoinProjectResponse {
2083        project_id: project.id.0 as u64,
2084        worktrees: worktrees.clone(),
2085        replica_id: replica_id.0 as u32,
2086        collaborators: collaborators.clone(),
2087        language_servers,
2088        language_server_capabilities,
2089        role: project.role.into(),
2090    })?;
2091
2092    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2093        // Stream this worktree's entries.
2094        let message = proto::UpdateWorktree {
2095            project_id: project_id.to_proto(),
2096            worktree_id,
2097            abs_path: worktree.abs_path.clone(),
2098            root_name: worktree.root_name,
2099            updated_entries: worktree.entries,
2100            removed_entries: Default::default(),
2101            scan_id: worktree.scan_id,
2102            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2103            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2104            removed_repositories: Default::default(),
2105        };
2106        for update in proto::split_worktree_update(message) {
2107            session.peer.send(session.connection_id, update.clone())?;
2108        }
2109
2110        // Stream this worktree's diagnostics.
2111        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
2112        if let Some(summary) = worktree_diagnostics.next() {
2113            let message = proto::UpdateDiagnosticSummary {
2114                project_id: project.id.to_proto(),
2115                worktree_id: worktree.id,
2116                summary: Some(summary),
2117                more_summaries: worktree_diagnostics.collect(),
2118            };
2119            session.peer.send(session.connection_id, message)?;
2120        }
2121
2122        for settings_file in worktree.settings_files {
2123            session.peer.send(
2124                session.connection_id,
2125                proto::UpdateWorktreeSettings {
2126                    project_id: project_id.to_proto(),
2127                    worktree_id: worktree.id,
2128                    path: settings_file.path,
2129                    content: Some(settings_file.content),
2130                    kind: Some(settings_file.kind.to_proto() as i32),
2131                },
2132            )?;
2133        }
2134    }
2135
2136    for repository in mem::take(&mut project.repositories) {
2137        for update in split_repository_update(repository) {
2138            session.peer.send(session.connection_id, update)?;
2139        }
2140    }
2141
2142    for language_server in &project.language_servers {
2143        session.peer.send(
2144            session.connection_id,
2145            proto::UpdateLanguageServer {
2146                project_id: project_id.to_proto(),
2147                server_name: Some(language_server.server.name.clone()),
2148                language_server_id: language_server.server.id,
2149                variant: Some(
2150                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2151                        proto::LspDiskBasedDiagnosticsUpdated {},
2152                    ),
2153                ),
2154            },
2155        )?;
2156    }
2157
2158    Ok(())
2159}
2160
2161/// Leave someone else's shared project.
2162async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2163    let sender_id = session.connection_id;
2164    let project_id = ProjectId::from_proto(request.project_id);
2165    let db = session.db().await;
2166
2167    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2168    tracing::info!(
2169        %project_id,
2170        "leave project"
2171    );
2172
2173    project_left(project, &session);
2174    if let Some(room) = room {
2175        room_updated(room, &session.peer);
2176    }
2177
2178    Ok(())
2179}
2180
2181/// Updates other participants with changes to the project.
2182async fn update_project(
2183    request: proto::UpdateProject,
2184    response: Response<proto::UpdateProject>,
2185    session: MessageContext,
2186) -> Result<()> {
2187    let project_id = ProjectId::from_proto(request.project_id);
2188    let (room, guest_connection_ids) = &*session
2189        .db()
2190        .await
2191        .update_project(project_id, session.connection_id, &request.worktrees)
2192        .await?;
2193    broadcast(
2194        Some(session.connection_id),
2195        guest_connection_ids.iter().copied(),
2196        |connection_id| {
2197            session
2198                .peer
2199                .forward_send(session.connection_id, connection_id, request.clone())
2200        },
2201    );
2202    if let Some(room) = room {
2203        room_updated(room, &session.peer);
2204    }
2205    response.send(proto::Ack {})?;
2206
2207    Ok(())
2208}
2209
2210/// Updates other participants with changes to the worktree.
2211async fn update_worktree(
2212    request: proto::UpdateWorktree,
2213    response: Response<proto::UpdateWorktree>,
2214    session: MessageContext,
2215) -> Result<()> {
2216    let guest_connection_ids = session
2217        .db()
2218        .await
2219        .update_worktree(&request, session.connection_id)
2220        .await?;
2221
2222    broadcast(
2223        Some(session.connection_id),
2224        guest_connection_ids.iter().copied(),
2225        |connection_id| {
2226            session
2227                .peer
2228                .forward_send(session.connection_id, connection_id, request.clone())
2229        },
2230    );
2231    response.send(proto::Ack {})?;
2232    Ok(())
2233}
2234
2235async fn update_repository(
2236    request: proto::UpdateRepository,
2237    response: Response<proto::UpdateRepository>,
2238    session: MessageContext,
2239) -> Result<()> {
2240    let guest_connection_ids = session
2241        .db()
2242        .await
2243        .update_repository(&request, session.connection_id)
2244        .await?;
2245
2246    broadcast(
2247        Some(session.connection_id),
2248        guest_connection_ids.iter().copied(),
2249        |connection_id| {
2250            session
2251                .peer
2252                .forward_send(session.connection_id, connection_id, request.clone())
2253        },
2254    );
2255    response.send(proto::Ack {})?;
2256    Ok(())
2257}
2258
2259async fn remove_repository(
2260    request: proto::RemoveRepository,
2261    response: Response<proto::RemoveRepository>,
2262    session: MessageContext,
2263) -> Result<()> {
2264    let guest_connection_ids = session
2265        .db()
2266        .await
2267        .remove_repository(&request, session.connection_id)
2268        .await?;
2269
2270    broadcast(
2271        Some(session.connection_id),
2272        guest_connection_ids.iter().copied(),
2273        |connection_id| {
2274            session
2275                .peer
2276                .forward_send(session.connection_id, connection_id, request.clone())
2277        },
2278    );
2279    response.send(proto::Ack {})?;
2280    Ok(())
2281}
2282
2283/// Updates other participants with changes to the diagnostics.
2284async fn update_diagnostic_summary(
2285    message: proto::UpdateDiagnosticSummary,
2286    session: MessageContext,
2287) -> Result<()> {
2288    let guest_connection_ids = session
2289        .db()
2290        .await
2291        .update_diagnostic_summary(&message, session.connection_id)
2292        .await?;
2293
2294    broadcast(
2295        Some(session.connection_id),
2296        guest_connection_ids.iter().copied(),
2297        |connection_id| {
2298            session
2299                .peer
2300                .forward_send(session.connection_id, connection_id, message.clone())
2301        },
2302    );
2303
2304    Ok(())
2305}
2306
2307/// Updates other participants with changes to the worktree settings.
2308async fn update_worktree_settings(
2309    message: proto::UpdateWorktreeSettings,
2310    session: MessageContext,
2311) -> Result<()> {
2312    let guest_connection_ids = session
2313        .db()
2314        .await
2315        .update_worktree_settings(&message, session.connection_id)
2316        .await?;
2317
2318    broadcast(
2319        Some(session.connection_id),
2320        guest_connection_ids.iter().copied(),
2321        |connection_id| {
2322            session
2323                .peer
2324                .forward_send(session.connection_id, connection_id, message.clone())
2325        },
2326    );
2327
2328    Ok(())
2329}
2330
2331/// Notify other participants that a language server has started.
2332async fn start_language_server(
2333    request: proto::StartLanguageServer,
2334    session: MessageContext,
2335) -> Result<()> {
2336    let guest_connection_ids = session
2337        .db()
2338        .await
2339        .start_language_server(&request, session.connection_id)
2340        .await?;
2341
2342    broadcast(
2343        Some(session.connection_id),
2344        guest_connection_ids.iter().copied(),
2345        |connection_id| {
2346            session
2347                .peer
2348                .forward_send(session.connection_id, connection_id, request.clone())
2349        },
2350    );
2351    Ok(())
2352}
2353
2354/// Notify other participants that a language server has changed.
2355async fn update_language_server(
2356    request: proto::UpdateLanguageServer,
2357    session: MessageContext,
2358) -> Result<()> {
2359    let project_id = ProjectId::from_proto(request.project_id);
2360    let db = session.db().await;
2361
2362    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2363    {
2364        if let Some(capabilities) = update.capabilities.clone() {
2365            db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2366                .await?;
2367        }
2368    }
2369
2370    let project_connection_ids = db
2371        .project_connection_ids(project_id, session.connection_id, true)
2372        .await?;
2373    broadcast(
2374        Some(session.connection_id),
2375        project_connection_ids.iter().copied(),
2376        |connection_id| {
2377            session
2378                .peer
2379                .forward_send(session.connection_id, connection_id, request.clone())
2380        },
2381    );
2382    Ok(())
2383}
2384
2385/// Forward a project request to the host. These requests should be read-only,
2386/// as guests are allowed to send them.
2387async fn forward_read_only_project_request<T>(
2388    request: T,
2389    response: Response<T>,
2390    session: MessageContext,
2391) -> Result<()>
2392where
2393    T: EntityMessage + RequestMessage,
2394{
2395    let project_id = ProjectId::from_proto(request.remote_entity_id());
2396    let host_connection_id = session
2397        .db()
2398        .await
2399        .host_for_read_only_project_request(project_id, session.connection_id)
2400        .await?;
2401    let payload = session.forward_request(host_connection_id, request).await?;
2402    response.send(payload)?;
2403    Ok(())
2404}
2405
2406/// Forward a project request to the host. These requests are disallowed
2407/// for guests.
2408async fn forward_mutating_project_request<T>(
2409    request: T,
2410    response: Response<T>,
2411    session: MessageContext,
2412) -> Result<()>
2413where
2414    T: EntityMessage + RequestMessage,
2415{
2416    let project_id = ProjectId::from_proto(request.remote_entity_id());
2417
2418    let host_connection_id = session
2419        .db()
2420        .await
2421        .host_for_mutating_project_request(project_id, session.connection_id)
2422        .await?;
2423    let payload = session.forward_request(host_connection_id, request).await?;
2424    response.send(payload)?;
2425    Ok(())
2426}
2427
2428async fn multi_lsp_query(
2429    request: MultiLspQuery,
2430    response: Response<MultiLspQuery>,
2431    session: MessageContext,
2432) -> Result<()> {
2433    tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2434    tracing::info!("multi_lsp_query message received");
2435    forward_mutating_project_request(request, response, session).await
2436}
2437
2438/// Notify other participants that a new buffer has been created.
2439async fn create_buffer_for_peer(
2440    request: proto::CreateBufferForPeer,
2441    session: MessageContext,
2442) -> Result<()> {
2443    session
2444        .db()
2445        .await
2446        .check_user_is_project_host(
2447            ProjectId::from_proto(request.project_id),
2448            session.connection_id,
2449        )
2450        .await?;
2451    let peer_id = request.peer_id.context("invalid peer id")?;
2452    session
2453        .peer
2454        .forward_send(session.connection_id, peer_id.into(), request)?;
2455    Ok(())
2456}
2457
2458/// Notify other participants that a buffer has been updated. This is
2459/// allowed for guests as long as the update is limited to selections.
2460async fn update_buffer(
2461    request: proto::UpdateBuffer,
2462    response: Response<proto::UpdateBuffer>,
2463    session: MessageContext,
2464) -> Result<()> {
2465    let project_id = ProjectId::from_proto(request.project_id);
2466    let mut capability = Capability::ReadOnly;
2467
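    // Operations that only update selections are allowed for read-only guests;
    // any other operation implies an edit and requires write access.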
2468    for op in request.operations.iter() {
2469        match op.variant {
2470            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2471            Some(_) => capability = Capability::ReadWrite,
2472        }
2473    }
2474
2475    let host = {
2476        let guard = session
2477            .db()
2478            .await
2479            .connections_for_buffer_update(project_id, session.connection_id, capability)
2480            .await?;
2481
2482        let (host, guests) = &*guard;
2483
2484        broadcast(
2485            Some(session.connection_id),
2486            guests.clone(),
2487            |connection_id| {
2488                session
2489                    .peer
2490                    .forward_send(session.connection_id, connection_id, request.clone())
2491            },
2492        );
2493
2494        *host
2495    };
2496
2497    if host != session.connection_id {
2498        session.forward_request(host, request.clone()).await?;
2499    }
2500
2501    response.send(proto::Ack {})?;
2502    Ok(())
2503}
2504
2505async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2506    let project_id = ProjectId::from_proto(message.project_id);
2507
2508    let operation = message.operation.as_ref().context("invalid operation")?;
2509    let capability = match operation.variant.as_ref() {
2510        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2511            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2512                match buffer_op.variant {
2513                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2514                        Capability::ReadOnly
2515                    }
2516                    _ => Capability::ReadWrite,
2517                }
2518            } else {
2519                Capability::ReadWrite
2520            }
2521        }
2522        Some(_) => Capability::ReadWrite,
2523        None => Capability::ReadOnly,
2524    };
2525
2526    let guard = session
2527        .db()
2528        .await
2529        .connections_for_buffer_update(project_id, session.connection_id, capability)
2530        .await?;
2531
2532    let (host, guests) = &*guard;
2533
2534    broadcast(
2535        Some(session.connection_id),
2536        guests.iter().chain([host]).copied(),
2537        |connection_id| {
2538            session
2539                .peer
2540                .forward_send(session.connection_id, connection_id, message.clone())
2541        },
2542    );
2543
2544    Ok(())
2545}
2546
2547/// Notify other participants that a project has been updated.
2548async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2549    request: T,
2550    session: MessageContext,
2551) -> Result<()> {
2552    let project_id = ProjectId::from_proto(request.remote_entity_id());
2553    let project_connection_ids = session
2554        .db()
2555        .await
2556        .project_connection_ids(project_id, session.connection_id, false)
2557        .await?;
2558
2559    broadcast(
2560        Some(session.connection_id),
2561        project_connection_ids.iter().copied(),
2562        |connection_id| {
2563            session
2564                .peer
2565                .forward_send(session.connection_id, connection_id, request.clone())
2566        },
2567    );
2568    Ok(())
2569}
2570
2571/// Start following another user in a call.
2572async fn follow(
2573    request: proto::Follow,
2574    response: Response<proto::Follow>,
2575    session: MessageContext,
2576) -> Result<()> {
2577    let room_id = RoomId::from_proto(request.room_id);
2578    let project_id = request.project_id.map(ProjectId::from_proto);
2579    let leader_id = request.leader_id.context("invalid leader id")?.into();
2580    let follower_id = session.connection_id;
2581
2582    session
2583        .db()
2584        .await
2585        .check_room_participants(room_id, leader_id, session.connection_id)
2586        .await?;
2587
2588    let response_payload = session.forward_request(leader_id, request).await?;
2589    response.send(response_payload)?;
2590
2591    if let Some(project_id) = project_id {
2592        let room = session
2593            .db()
2594            .await
2595            .follow(room_id, project_id, leader_id, follower_id)
2596            .await?;
2597        room_updated(&room, &session.peer);
2598    }
2599
2600    Ok(())
2601}
2602
2603/// Stop following another user in a call.
2604async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2605    let room_id = RoomId::from_proto(request.room_id);
2606    let project_id = request.project_id.map(ProjectId::from_proto);
2607    let leader_id = request.leader_id.context("invalid leader id")?.into();
2608    let follower_id = session.connection_id;
2609
2610    session
2611        .db()
2612        .await
2613        .check_room_participants(room_id, leader_id, session.connection_id)
2614        .await?;
2615
2616    session
2617        .peer
2618        .forward_send(session.connection_id, leader_id, request)?;
2619
2620    if let Some(project_id) = project_id {
2621        let room = session
2622            .db()
2623            .await
2624            .unfollow(room_id, project_id, leader_id, follower_id)
2625            .await?;
2626        room_updated(&room, &session.peer);
2627    }
2628
2629    Ok(())
2630}
2631
2632/// Notify everyone following you of your current location.
2633async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2634    let room_id = RoomId::from_proto(request.room_id);
2635    let database = session.db.lock().await;
2636
2637    let connection_ids = if let Some(project_id) = request.project_id {
2638        let project_id = ProjectId::from_proto(project_id);
2639        database
2640            .project_connection_ids(project_id, session.connection_id, true)
2641            .await?
2642    } else {
2643        database
2644            .room_connection_ids(room_id, session.connection_id)
2645            .await?
2646    };
2647
2648    // For now, don't send view update messages back to that view's current leader.
2649    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2650        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2651        _ => None,
2652    });
2653
2654    for connection_id in connection_ids.iter().cloned() {
2655        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2656            session
2657                .peer
2658                .forward_send(session.connection_id, connection_id, request.clone())?;
2659        }
2660    }
2661    Ok(())
2662}
2663
2664/// Get public data about users.
2665async fn get_users(
2666    request: proto::GetUsers,
2667    response: Response<proto::GetUsers>,
2668    session: MessageContext,
2669) -> Result<()> {
2670    let user_ids = request
2671        .user_ids
2672        .into_iter()
2673        .map(UserId::from_proto)
2674        .collect();
2675    let users = session
2676        .db()
2677        .await
2678        .get_users_by_ids(user_ids)
2679        .await?
2680        .into_iter()
2681        .map(|user| proto::User {
2682            id: user.id.to_proto(),
2683            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2684            github_login: user.github_login,
2685            name: user.name,
2686        })
2687        .collect();
2688    response.send(proto::UsersResponse { users })?;
2689    Ok(())
2690}
2691
2692/// Search for users (to invite) by GitHub login.
2693async fn fuzzy_search_users(
2694    request: proto::FuzzySearchUsers,
2695    response: Response<proto::FuzzySearchUsers>,
2696    session: MessageContext,
2697) -> Result<()> {
2698    let query = request.query;
2699    let users = match query.len() {
2700        0 => vec![],
2701        1 | 2 => session
2702            .db()
2703            .await
2704            .get_user_by_github_login(&query)
2705            .await?
2706            .into_iter()
2707            .collect(),
2708        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2709    };
2710    let users = users
2711        .into_iter()
2712        .filter(|user| user.id != session.user_id())
2713        .map(|user| proto::User {
2714            id: user.id.to_proto(),
2715            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2716            github_login: user.github_login,
2717            name: user.name,
2718        })
2719        .collect();
2720    response.send(proto::UsersResponse { users })?;
2721    Ok(())
2722}
2723
2724/// Send a contact request to another user.
2725async fn request_contact(
2726    request: proto::RequestContact,
2727    response: Response<proto::RequestContact>,
2728    session: MessageContext,
2729) -> Result<()> {
2730    let requester_id = session.user_id();
2731    let responder_id = UserId::from_proto(request.responder_id);
2732    if requester_id == responder_id {
2733        return Err(anyhow!("cannot add yourself as a contact"))?;
2734    }
2735
2736    let notifications = session
2737        .db()
2738        .await
2739        .send_contact_request(requester_id, responder_id)
2740        .await?;
2741
2742    // Update outgoing contact requests of requester
2743    let mut update = proto::UpdateContacts::default();
2744    update.outgoing_requests.push(responder_id.to_proto());
2745    for connection_id in session
2746        .connection_pool()
2747        .await
2748        .user_connection_ids(requester_id)
2749    {
2750        session.peer.send(connection_id, update.clone())?;
2751    }
2752
2753    // Update incoming contact requests of responder
2754    let mut update = proto::UpdateContacts::default();
2755    update
2756        .incoming_requests
2757        .push(proto::IncomingContactRequest {
2758            requester_id: requester_id.to_proto(),
2759        });
2760    let connection_pool = session.connection_pool().await;
2761    for connection_id in connection_pool.user_connection_ids(responder_id) {
2762        session.peer.send(connection_id, update.clone())?;
2763    }
2764
2765    send_notifications(&connection_pool, &session.peer, notifications);
2766
2767    response.send(proto::Ack {})?;
2768    Ok(())
2769}
2770
2771/// Accept or decline a contact request.
2772async fn respond_to_contact_request(
2773    request: proto::RespondToContactRequest,
2774    response: Response<proto::RespondToContactRequest>,
2775    session: MessageContext,
2776) -> Result<()> {
2777    let responder_id = session.user_id();
2778    let requester_id = UserId::from_proto(request.requester_id);
2779    let db = session.db().await;
2780    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2781        db.dismiss_contact_notification(responder_id, requester_id)
2782            .await?;
2783    } else {
2784        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2785
2786        let notifications = db
2787            .respond_to_contact_request(responder_id, requester_id, accept)
2788            .await?;
2789        let requester_busy = db.is_user_busy(requester_id).await?;
2790        let responder_busy = db.is_user_busy(responder_id).await?;
2791
2792        let pool = session.connection_pool().await;
2793        // Update responder with new contact
2794        let mut update = proto::UpdateContacts::default();
2795        if accept {
2796            update
2797                .contacts
2798                .push(contact_for_user(requester_id, requester_busy, &pool));
2799        }
2800        update
2801            .remove_incoming_requests
2802            .push(requester_id.to_proto());
2803        for connection_id in pool.user_connection_ids(responder_id) {
2804            session.peer.send(connection_id, update.clone())?;
2805        }
2806
2807        // Update requester with new contact
2808        let mut update = proto::UpdateContacts::default();
2809        if accept {
2810            update
2811                .contacts
2812                .push(contact_for_user(responder_id, responder_busy, &pool));
2813        }
2814        update
2815            .remove_outgoing_requests
2816            .push(responder_id.to_proto());
2817
2818        for connection_id in pool.user_connection_ids(requester_id) {
2819            session.peer.send(connection_id, update.clone())?;
2820        }
2821
2822        send_notifications(&pool, &session.peer, notifications);
2823    }
2824
2825    response.send(proto::Ack {})?;
2826    Ok(())
2827}
2828
2829/// Remove a contact.
2830async fn remove_contact(
2831    request: proto::RemoveContact,
2832    response: Response<proto::RemoveContact>,
2833    session: MessageContext,
2834) -> Result<()> {
2835    let requester_id = session.user_id();
2836    let responder_id = UserId::from_proto(request.user_id);
2837    let db = session.db().await;
2838    let (contact_accepted, deleted_notification_id) =
2839        db.remove_contact(requester_id, responder_id).await?;
2840
2841    let pool = session.connection_pool().await;
2842    // Update outgoing contact requests of requester
2843    let mut update = proto::UpdateContacts::default();
2844    if contact_accepted {
2845        update.remove_contacts.push(responder_id.to_proto());
2846    } else {
2847        update
2848            .remove_outgoing_requests
2849            .push(responder_id.to_proto());
2850    }
2851    for connection_id in pool.user_connection_ids(requester_id) {
2852        session.peer.send(connection_id, update.clone())?;
2853    }
2854
2855    // Update incoming contact requests of responder
2856    let mut update = proto::UpdateContacts::default();
2857    if contact_accepted {
2858        update.remove_contacts.push(requester_id.to_proto());
2859    } else {
2860        update
2861            .remove_incoming_requests
2862            .push(requester_id.to_proto());
2863    }
2864    for connection_id in pool.user_connection_ids(responder_id) {
2865        session.peer.send(connection_id, update.clone())?;
2866        if let Some(notification_id) = deleted_notification_id {
2867            session.peer.send(
2868                connection_id,
2869                proto::DeleteNotification {
2870                    notification_id: notification_id.to_proto(),
2871                },
2872            )?;
2873        }
2874    }
2875
2876    response.send(proto::Ack {})?;
2877    Ok(())
2878}
2879
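    /// Returns whether the server should subscribe the client to its channels as soon as it
    /// connects. Newer clients (0.139+) request this explicitly via `SubscribeToChannels`.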
2880fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2881    version.0.minor() < 139
2882}
2883
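    /// Determines the user's current plan. Staff are always treated as Zed Pro; otherwise the
    /// plan is derived from the user's active billing subscription, defaulting to Free.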
2884async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2885    if is_staff {
2886        return Ok(proto::Plan::ZedPro);
2887    }
2888
2889    let subscription = db.get_active_billing_subscription(user_id).await?;
2890    let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2891
2892    let plan = if let Some(subscription_kind) = subscription_kind {
2893        match subscription_kind {
2894            SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2895            SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2896            SubscriptionKind::ZedFree => proto::Plan::Free,
2897        }
2898    } else {
2899        proto::Plan::Free
2900    };
2901
2902    Ok(plan)
2903}
2904
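    /// Builds the `UpdateUserPlan` message describing the user's plan, trial status,
    /// subscription period, and usage for the current billing period.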
2905async fn make_update_user_plan_message(
2906    user: &User,
2907    is_staff: bool,
2908    db: &Arc<Database>,
2909    llm_db: Option<Arc<LlmDatabase>>,
2910) -> Result<proto::UpdateUserPlan> {
2911    let feature_flags = db.get_user_flags(user.id).await?;
2912    let plan = current_plan(db, user.id, is_staff).await?;
2913    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
2914    let billing_preferences = db.get_billing_preferences(user.id).await?;
2915
2916    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
2917        let subscription = db.get_active_billing_subscription(user.id).await?;
2918
2919        let subscription_period =
2920            crate::db::billing_subscription::Model::current_period(subscription, is_staff);
2921
2922        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
2923            llm_db
2924                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
2925                .await?
2926        } else {
2927            None
2928        };
2929
2930        (subscription_period, usage)
2931    } else {
2932        (None, None)
2933    };
2934
2935    let bypass_account_age_check = feature_flags
2936        .iter()
2937        .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
2938    let account_too_young = !matches!(plan, proto::Plan::ZedPro)
2939        && !bypass_account_age_check
2940        && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;
2941
2942    Ok(proto::UpdateUserPlan {
2943        plan: plan.into(),
2944        trial_started_at: billing_customer
2945            .as_ref()
2946            .and_then(|billing_customer| billing_customer.trial_started_at)
2947            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2948        is_usage_based_billing_enabled: if is_staff {
2949            Some(true)
2950        } else {
2951            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
2952        },
2953        subscription_period: subscription_period.map(|(started_at, ended_at)| {
2954            proto::SubscriptionPeriod {
2955                started_at: started_at.timestamp() as u64,
2956                ended_at: ended_at.timestamp() as u64,
2957            }
2958        }),
2959        account_too_young: Some(account_too_young),
2960        has_overdue_invoices: billing_customer
2961            .map(|billing_customer| billing_customer.has_overdue_invoices),
2962        usage: Some(
2963            usage
2964                .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
2965                .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
2966        ),
2967    })
2968}
2969
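    /// Returns the model request limit for the given plan, raising the Pro trial limit to 1,000
    /// for users with the extended-trial feature flag.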
2970fn model_requests_limit(
2971    plan: cloud_llm_client::Plan,
2972    feature_flags: &[String],
2973) -> cloud_llm_client::UsageLimit {
2974    match plan.model_requests_limit() {
2975        cloud_llm_client::UsageLimit::Limited(limit) => {
2976            let limit = if plan == cloud_llm_client::Plan::ZedProTrial
2977                && feature_flags
2978                    .iter()
2979                    .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2980            {
2981                1_000
2982            } else {
2983                limit
2984            };
2985
2986            cloud_llm_client::UsageLimit::Limited(limit)
2987        }
2988        cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited,
2989    }
2990}
2991
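    /// Converts a subscription usage record into its protobuf representation, including the
    /// usage limits that apply to the given plan.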
2992fn subscription_usage_to_proto(
2993    plan: proto::Plan,
2994    usage: crate::llm::db::subscription_usage::Model,
2995    feature_flags: &[String],
2996) -> proto::SubscriptionUsage {
2997    let plan = match plan {
2998        proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2999        proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
3000        proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
3001    };
3002
3003    proto::SubscriptionUsage {
3004        model_requests_usage_amount: usage.model_requests as u32,
3005        model_requests_usage_limit: Some(proto::UsageLimit {
3006            variant: Some(match model_requests_limit(plan, feature_flags) {
3007                cloud_llm_client::UsageLimit::Limited(limit) => {
3008                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3009                        limit: limit as u32,
3010                    })
3011                }
3012                cloud_llm_client::UsageLimit::Unlimited => {
3013                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3014                }
3015            }),
3016        }),
3017        edit_predictions_usage_amount: usage.edit_predictions as u32,
3018        edit_predictions_usage_limit: Some(proto::UsageLimit {
3019            variant: Some(match plan.edit_predictions_limit() {
3020                cloud_llm_client::UsageLimit::Limited(limit) => {
3021                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3022                        limit: limit as u32,
3023                    })
3024                }
3025                cloud_llm_client::UsageLimit::Unlimited => {
3026                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3027                }
3028            }),
3029        }),
3030    }
3031}
3032
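    /// Builds a `SubscriptionUsage` with zero usage, used when no usage record exists for the
    /// current billing period.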
3033fn make_default_subscription_usage(
3034    plan: proto::Plan,
3035    feature_flags: &[String],
3036) -> proto::SubscriptionUsage {
3037    let plan = match plan {
3038        proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
3039        proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
3040        proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
3041    };
3042
3043    proto::SubscriptionUsage {
3044        model_requests_usage_amount: 0,
3045        model_requests_usage_limit: Some(proto::UsageLimit {
3046            variant: Some(match model_requests_limit(plan, feature_flags) {
3047                cloud_llm_client::UsageLimit::Limited(limit) => {
3048                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3049                        limit: limit as u32,
3050                    })
3051                }
3052                cloud_llm_client::UsageLimit::Unlimited => {
3053                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3054                }
3055            }),
3056        }),
3057        edit_predictions_usage_amount: 0,
3058        edit_predictions_usage_limit: Some(proto::UsageLimit {
3059            variant: Some(match plan.edit_predictions_limit() {
3060                cloud_llm_client::UsageLimit::Limited(limit) => {
3061                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3062                        limit: limit as u32,
3063                    })
3064                }
3065                cloud_llm_client::UsageLimit::Unlimited => {
3066                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3067                }
3068            }),
3069        }),
3070    }
3071}
3072
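    /// Sends the user's current plan and usage information to their connection.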
3073async fn update_user_plan(session: &Session) -> Result<()> {
3074    let db = session.db().await;
3075
3076    let update_user_plan = make_update_user_plan_message(
3077        session.principal.user(),
3078        session.is_staff(),
3079        &db.0,
3080        session.app_state.llm_db.clone(),
3081    )
3082    .await?;
3083
3084    session
3085        .peer
3086        .send(session.connection_id, update_user_plan)
3087        .trace_err();
3088
3089    Ok(())
3090}
3091
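    /// Subscribe the current connection to updates for the channels the user belongs to.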
3092async fn subscribe_to_channels(
3093    _: proto::SubscribeToChannels,
3094    session: MessageContext,
3095) -> Result<()> {
3096    subscribe_user_to_channels(session.user_id(), &session).await?;
3097    Ok(())
3098}
3099
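    /// Subscribes the given user to their channels in the connection pool and sends the initial
    /// channel state over the current connection.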
3100async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<()> {
3101    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3102    let mut pool = session.connection_pool().await;
3103    for membership in &channels_for_user.channel_memberships {
3104        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3105    }
3106    session.peer.send(
3107        session.connection_id,
3108        build_update_user_channels(&channels_for_user),
3109    )?;
3110    session.peer.send(
3111        session.connection_id,
3112        build_channels_update(channels_for_user),
3113    )?;
3114    Ok(())
3115}
3116
3117/// Creates a new channel.
3118async fn create_channel(
3119    request: proto::CreateChannel,
3120    response: Response<proto::CreateChannel>,
3121    session: MessageContext,
3122) -> Result<()> {
3123    let db = session.db().await;
3124
3125    let parent_id = request.parent_id.map(ChannelId::from_proto);
3126    let (channel, membership) = db
3127        .create_channel(&request.name, parent_id, session.user_id())
3128        .await?;
3129
3130    let root_id = channel.root_id();
3131    let channel = Channel::from_model(channel);
3132
3133    response.send(proto::CreateChannelResponse {
3134        channel: Some(channel.to_proto()),
3135        parent_id: request.parent_id,
3136    })?;
3137
3138    let mut connection_pool = session.connection_pool().await;
3139    if let Some(membership) = membership {
3140        connection_pool.subscribe_to_channel(
3141            membership.user_id,
3142            membership.channel_id,
3143            membership.role,
3144        );
3145        let update = proto::UpdateUserChannels {
3146            channel_memberships: vec![proto::ChannelMembership {
3147                channel_id: membership.channel_id.to_proto(),
3148                role: membership.role.into(),
3149            }],
3150            ..Default::default()
3151        };
3152        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3153            session.peer.send(connection_id, update.clone())?;
3154        }
3155    }
3156
3157    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3158        if !role.can_see_channel(channel.visibility) {
3159            continue;
3160        }
3161
3162        let update = proto::UpdateChannels {
3163            channels: vec![channel.to_proto()],
3164            ..Default::default()
3165        };
3166        session.peer.send(connection_id, update.clone())?;
3167    }
3168
3169    Ok(())
3170}
3171
3172/// Delete a channel
3173async fn delete_channel(
3174    request: proto::DeleteChannel,
3175    response: Response<proto::DeleteChannel>,
3176    session: MessageContext,
3177) -> Result<()> {
3178    let db = session.db().await;
3179
3180    let channel_id = request.channel_id;
3181    let (root_channel, removed_channels) = db
3182        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3183        .await?;
3184    response.send(proto::Ack {})?;
3185
3186    // Notify members of removed channels
3187    let mut update = proto::UpdateChannels::default();
3188    update
3189        .delete_channels
3190        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3191
3192    let connection_pool = session.connection_pool().await;
3193    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3194        session.peer.send(connection_id, update.clone())?;
3195    }
3196
3197    Ok(())
3198}
3199
3200/// Invite someone to join a channel.
3201async fn invite_channel_member(
3202    request: proto::InviteChannelMember,
3203    response: Response<proto::InviteChannelMember>,
3204    session: MessageContext,
3205) -> Result<()> {
3206    let db = session.db().await;
3207    let channel_id = ChannelId::from_proto(request.channel_id);
3208    let invitee_id = UserId::from_proto(request.user_id);
3209    let InviteMemberResult {
3210        channel,
3211        notifications,
3212    } = db
3213        .invite_channel_member(
3214            channel_id,
3215            invitee_id,
3216            session.user_id(),
3217            request.role().into(),
3218        )
3219        .await?;
3220
3221    let update = proto::UpdateChannels {
3222        channel_invitations: vec![channel.to_proto()],
3223        ..Default::default()
3224    };
3225
3226    let connection_pool = session.connection_pool().await;
3227    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3228        session.peer.send(connection_id, update.clone())?;
3229    }
3230
3231    send_notifications(&connection_pool, &session.peer, notifications);
3232
3233    response.send(proto::Ack {})?;
3234    Ok(())
3235}
3236
3237/// Remove someone from a channel.
3238async fn remove_channel_member(
3239    request: proto::RemoveChannelMember,
3240    response: Response<proto::RemoveChannelMember>,
3241    session: MessageContext,
3242) -> Result<()> {
3243    let db = session.db().await;
3244    let channel_id = ChannelId::from_proto(request.channel_id);
3245    let member_id = UserId::from_proto(request.user_id);
3246
3247    let RemoveChannelMemberResult {
3248        membership_update,
3249        notification_id,
3250    } = db
3251        .remove_channel_member(channel_id, member_id, session.user_id())
3252        .await?;
3253
3254    let mut connection_pool = session.connection_pool().await;
3255    notify_membership_updated(
3256        &mut connection_pool,
3257        membership_update,
3258        member_id,
3259        &session.peer,
3260    );
3261    for connection_id in connection_pool.user_connection_ids(member_id) {
3262        if let Some(notification_id) = notification_id {
3263            session
3264                .peer
3265                .send(
3266                    connection_id,
3267                    proto::DeleteNotification {
3268                        notification_id: notification_id.to_proto(),
3269                    },
3270                )
3271                .trace_err();
3272        }
3273    }
3274
3275    response.send(proto::Ack {})?;
3276    Ok(())
3277}
3278
3279/// Toggle the channel between public and private.
3280/// Care is taken to maintain the invariant that public channels only descend from public channels
3281/// (members-only channels, however, can appear at any point in the hierarchy).
3282async fn set_channel_visibility(
3283    request: proto::SetChannelVisibility,
3284    response: Response<proto::SetChannelVisibility>,
3285    session: MessageContext,
3286) -> Result<()> {
3287    let db = session.db().await;
3288    let channel_id = ChannelId::from_proto(request.channel_id);
3289    let visibility = request.visibility().into();
3290
3291    let channel_model = db
3292        .set_channel_visibility(channel_id, visibility, session.user_id())
3293        .await?;
3294    let root_id = channel_model.root_id();
3295    let channel = Channel::from_model(channel_model);
3296
3297    let mut connection_pool = session.connection_pool().await;
3298    for (user_id, role) in connection_pool
3299        .channel_user_ids(root_id)
3300        .collect::<Vec<_>>()
3301        .into_iter()
3302    {
3303        let update = if role.can_see_channel(channel.visibility) {
3304            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3305            proto::UpdateChannels {
3306                channels: vec![channel.to_proto()],
3307                ..Default::default()
3308            }
3309        } else {
3310            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3311            proto::UpdateChannels {
3312                delete_channels: vec![channel.id.to_proto()],
3313                ..Default::default()
3314            }
3315        };
3316
3317        for connection_id in connection_pool.user_connection_ids(user_id) {
3318            session.peer.send(connection_id, update.clone())?;
3319        }
3320    }
3321
3322    response.send(proto::Ack {})?;
3323    Ok(())
3324}
3325
3326/// Alter the role for a user in the channel.
3327async fn set_channel_member_role(
3328    request: proto::SetChannelMemberRole,
3329    response: Response<proto::SetChannelMemberRole>,
3330    session: MessageContext,
3331) -> Result<()> {
3332    let db = session.db().await;
3333    let channel_id = ChannelId::from_proto(request.channel_id);
3334    let member_id = UserId::from_proto(request.user_id);
3335    let result = db
3336        .set_channel_member_role(
3337            channel_id,
3338            session.user_id(),
3339            member_id,
3340            request.role().into(),
3341        )
3342        .await?;
3343
3344    match result {
3345        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3346            let mut connection_pool = session.connection_pool().await;
3347            notify_membership_updated(
3348                &mut connection_pool,
3349                membership_update,
3350                member_id,
3351                &session.peer,
3352            )
3353        }
3354        db::SetMemberRoleResult::InviteUpdated(channel) => {
3355            let update = proto::UpdateChannels {
3356                channel_invitations: vec![channel.to_proto()],
3357                ..Default::default()
3358            };
3359
3360            for connection_id in session
3361                .connection_pool()
3362                .await
3363                .user_connection_ids(member_id)
3364            {
3365                session.peer.send(connection_id, update.clone())?;
3366            }
3367        }
3368    }
3369
3370    response.send(proto::Ack {})?;
3371    Ok(())
3372}
3373
3374/// Change the name of a channel
3375async fn rename_channel(
3376    request: proto::RenameChannel,
3377    response: Response<proto::RenameChannel>,
3378    session: MessageContext,
3379) -> Result<()> {
3380    let db = session.db().await;
3381    let channel_id = ChannelId::from_proto(request.channel_id);
3382    let channel_model = db
3383        .rename_channel(channel_id, session.user_id(), &request.name)
3384        .await?;
3385    let root_id = channel_model.root_id();
3386    let channel = Channel::from_model(channel_model);
3387
3388    response.send(proto::RenameChannelResponse {
3389        channel: Some(channel.to_proto()),
3390    })?;
3391
3392    let connection_pool = session.connection_pool().await;
3393    let update = proto::UpdateChannels {
3394        channels: vec![channel.to_proto()],
3395        ..Default::default()
3396    };
3397    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3398        if role.can_see_channel(channel.visibility) {
3399            session.peer.send(connection_id, update.clone())?;
3400        }
3401    }
3402
3403    Ok(())
3404}
3405
3406/// Move a channel to a new parent.
3407async fn move_channel(
3408    request: proto::MoveChannel,
3409    response: Response<proto::MoveChannel>,
3410    session: MessageContext,
3411) -> Result<()> {
3412    let channel_id = ChannelId::from_proto(request.channel_id);
3413    let to = ChannelId::from_proto(request.to);
3414
3415    let (root_id, channels) = session
3416        .db()
3417        .await
3418        .move_channel(channel_id, to, session.user_id())
3419        .await?;
3420
3421    let connection_pool = session.connection_pool().await;
3422    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3423        let channels = channels
3424            .iter()
3425            .filter_map(|channel| {
3426                if role.can_see_channel(channel.visibility) {
3427                    Some(channel.to_proto())
3428                } else {
3429                    None
3430                }
3431            })
3432            .collect::<Vec<_>>();
3433        if channels.is_empty() {
3434            continue;
3435        }
3436
3437        let update = proto::UpdateChannels {
3438            channels,
3439            ..Default::default()
3440        };
3441
3442        session.peer.send(connection_id, update.clone())?;
3443    }
3444
3445    response.send(Ack {})?;
3446    Ok(())
3447}
3448
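    /// Reorder a channel relative to its siblings.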
3449async fn reorder_channel(
3450    request: proto::ReorderChannel,
3451    response: Response<proto::ReorderChannel>,
3452    session: MessageContext,
3453) -> Result<()> {
3454    let channel_id = ChannelId::from_proto(request.channel_id);
3455    let direction = request.direction();
3456
3457    let updated_channels = session
3458        .db()
3459        .await
3460        .reorder_channel(channel_id, direction, session.user_id())
3461        .await?;
3462
3463    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3464        let connection_pool = session.connection_pool().await;
3465        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3466            let channels = updated_channels
3467                .iter()
3468                .filter_map(|channel| {
3469                    if role.can_see_channel(channel.visibility) {
3470                        Some(channel.to_proto())
3471                    } else {
3472                        None
3473                    }
3474                })
3475                .collect::<Vec<_>>();
3476
3477            if channels.is_empty() {
3478                continue;
3479            }
3480
3481            let update = proto::UpdateChannels {
3482                channels,
3483                ..Default::default()
3484            };
3485
3486            session.peer.send(connection_id, update.clone())?;
3487        }
3488    }
3489
3490    response.send(Ack {})?;
3491    Ok(())
3492}
3493
3494/// Get the list of channel members
3495async fn get_channel_members(
3496    request: proto::GetChannelMembers,
3497    response: Response<proto::GetChannelMembers>,
3498    session: MessageContext,
3499) -> Result<()> {
3500    let db = session.db().await;
3501    let channel_id = ChannelId::from_proto(request.channel_id);
3502    let limit = if request.limit == 0 {
3503        u16::MAX as u64
3504    } else {
3505        request.limit
3506    };
3507    let (members, users) = db
3508        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3509        .await?;
3510    response.send(proto::GetChannelMembersResponse { members, users })?;
3511    Ok(())
3512}
3513
3514/// Accept or decline a channel invitation.
3515async fn respond_to_channel_invite(
3516    request: proto::RespondToChannelInvite,
3517    response: Response<proto::RespondToChannelInvite>,
3518    session: MessageContext,
3519) -> Result<()> {
3520    let db = session.db().await;
3521    let channel_id = ChannelId::from_proto(request.channel_id);
3522    let RespondToChannelInvite {
3523        membership_update,
3524        notifications,
3525    } = db
3526        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3527        .await?;
3528
3529    let mut connection_pool = session.connection_pool().await;
3530    if let Some(membership_update) = membership_update {
3531        notify_membership_updated(
3532            &mut connection_pool,
3533            membership_update,
3534            session.user_id(),
3535            &session.peer,
3536        );
3537    } else {
3538        let update = proto::UpdateChannels {
3539            remove_channel_invitations: vec![channel_id.to_proto()],
3540            ..Default::default()
3541        };
3542
3543        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3544            session.peer.send(connection_id, update.clone())?;
3545        }
3546    };
3547
3548    send_notifications(&connection_pool, &session.peer, notifications);
3549
3550    response.send(proto::Ack {})?;
3551
3552    Ok(())
3553}
3554
3555/// Join the channel's room
3556async fn join_channel(
3557    request: proto::JoinChannel,
3558    response: Response<proto::JoinChannel>,
3559    session: MessageContext,
3560) -> Result<()> {
3561    let channel_id = ChannelId::from_proto(request.channel_id);
3562    join_channel_internal(channel_id, Box::new(response), session).await
3563}
3564
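    /// Allows `join_channel_internal` to reply to either a `JoinChannel` or a `JoinRoom`
    /// request with the same `JoinRoomResponse`.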
3565trait JoinChannelInternalResponse {
3566    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3567}
3568impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3569    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3570        Response::<proto::JoinChannel>::send(self, result)
3571    }
3572}
3573impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3574    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3575        Response::<proto::JoinRoom>::send(self, result)
3576    }
3577}
3578
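    /// Join the room associated with a channel, cleaning up any stale room connection first and
    /// including LiveKit connection info when a call server is configured.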
3579async fn join_channel_internal(
3580    channel_id: ChannelId,
3581    response: Box<impl JoinChannelInternalResponse>,
3582    session: MessageContext,
3583) -> Result<()> {
3584    let joined_room = {
3585        let mut db = session.db().await;
3586        // If Zed quits without leaving the room and the user reopens Zed before the
3587        // RECONNECT_TIMEOUT, we need to kick the user out of the room they were
3588        // previously in.
3589        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3590            tracing::info!(
3591                stale_connection_id = %connection,
3592                "cleaning up stale connection",
3593            );
3594            drop(db);
3595            leave_room_for_session(&session, connection).await?;
3596            db = session.db().await;
3597        }
3598
3599        let (joined_room, membership_updated, role) = db
3600            .join_channel(channel_id, session.user_id(), session.connection_id)
3601            .await?;
3602
3603        let live_kit_connection_info =
3604            session
3605                .app_state
3606                .livekit_client
3607                .as_ref()
3608                .and_then(|live_kit| {
3609                    let (can_publish, token) = if role == ChannelRole::Guest {
3610                        (
3611                            false,
3612                            live_kit
3613                                .guest_token(
3614                                    &joined_room.room.livekit_room,
3615                                    &session.user_id().to_string(),
3616                                )
3617                                .trace_err()?,
3618                        )
3619                    } else {
3620                        (
3621                            true,
3622                            live_kit
3623                                .room_token(
3624                                    &joined_room.room.livekit_room,
3625                                    &session.user_id().to_string(),
3626                                )
3627                                .trace_err()?,
3628                        )
3629                    };
3630
3631                    Some(LiveKitConnectionInfo {
3632                        server_url: live_kit.url().into(),
3633                        token,
3634                        can_publish,
3635                    })
3636                });
3637
3638        response.send(proto::JoinRoomResponse {
3639            room: Some(joined_room.room.clone()),
3640            channel_id: joined_room
3641                .channel
3642                .as_ref()
3643                .map(|channel| channel.id.to_proto()),
3644            live_kit_connection_info,
3645        })?;
3646
3647        let mut connection_pool = session.connection_pool().await;
3648        if let Some(membership_updated) = membership_updated {
3649            notify_membership_updated(
3650                &mut connection_pool,
3651                membership_updated,
3652                session.user_id(),
3653                &session.peer,
3654            );
3655        }
3656
3657        room_updated(&joined_room.room, &session.peer);
3658
3659        joined_room
3660    };
3661
3662    channel_updated(
3663        &joined_room.channel.context("channel not returned")?,
3664        &joined_room.room,
3665        &session.peer,
3666        &*session.connection_pool().await,
3667    );
3668
3669    update_user_contacts(session.user_id(), &session).await?;
3670    Ok(())
3671}
3672
3673/// Start editing the channel notes
3674async fn join_channel_buffer(
3675    request: proto::JoinChannelBuffer,
3676    response: Response<proto::JoinChannelBuffer>,
3677    session: MessageContext,
3678) -> Result<()> {
3679    let db = session.db().await;
3680    let channel_id = ChannelId::from_proto(request.channel_id);
3681
3682    let open_response = db
3683        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3684        .await?;
3685
3686    let collaborators = open_response.collaborators.clone();
3687    response.send(open_response)?;
3688
3689    let update = UpdateChannelBufferCollaborators {
3690        channel_id: channel_id.to_proto(),
3691        collaborators: collaborators.clone(),
3692    };
3693    channel_buffer_updated(
3694        session.connection_id,
3695        collaborators
3696            .iter()
3697            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3698        &update,
3699        &session.peer,
3700    );
3701
3702    Ok(())
3703}
3704
3705/// Edit the channel notes
3706async fn update_channel_buffer(
3707    request: proto::UpdateChannelBuffer,
3708    session: MessageContext,
3709) -> Result<()> {
3710    let db = session.db().await;
3711    let channel_id = ChannelId::from_proto(request.channel_id);
3712
3713    let (collaborators, epoch, version) = db
3714        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3715        .await?;
3716
3717    channel_buffer_updated(
3718        session.connection_id,
3719        collaborators.clone(),
3720        &proto::UpdateChannelBuffer {
3721            channel_id: channel_id.to_proto(),
3722            operations: request.operations,
3723        },
3724        &session.peer,
3725    );
3726
3727    let pool = &*session.connection_pool().await;
3728
3729    let non_collaborators =
3730        pool.channel_connection_ids(channel_id)
3731            .filter_map(|(connection_id, _)| {
3732                if collaborators.contains(&connection_id) {
3733                    None
3734                } else {
3735                    Some(connection_id)
3736                }
3737            });
3738
3739    broadcast(None, non_collaborators, |peer_id| {
3740        session.peer.send(
3741            peer_id,
3742            proto::UpdateChannels {
3743                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3744                    channel_id: channel_id.to_proto(),
3745                    epoch: epoch as u64,
3746                    version: version.clone(),
3747                }],
3748                ..Default::default()
3749            },
3750        )
3751    });
3752
3753    Ok(())
3754}
3755
3756/// Rejoin the channel notes after a connection blip
3757async fn rejoin_channel_buffers(
3758    request: proto::RejoinChannelBuffers,
3759    response: Response<proto::RejoinChannelBuffers>,
3760    session: MessageContext,
3761) -> Result<()> {
3762    let db = session.db().await;
3763    let buffers = db
3764        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3765        .await?;
3766
3767    for rejoined_buffer in &buffers {
3768        let collaborators_to_notify = rejoined_buffer
3769            .buffer
3770            .collaborators
3771            .iter()
3772            .filter_map(|c| Some(c.peer_id?.into()));
3773        channel_buffer_updated(
3774            session.connection_id,
3775            collaborators_to_notify,
3776            &proto::UpdateChannelBufferCollaborators {
3777                channel_id: rejoined_buffer.buffer.channel_id,
3778                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3779            },
3780            &session.peer,
3781        );
3782    }
3783
3784    response.send(proto::RejoinChannelBuffersResponse {
3785        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3786    })?;
3787
3788    Ok(())
3789}
3790
3791/// Stop editing the channel notes
3792async fn leave_channel_buffer(
3793    request: proto::LeaveChannelBuffer,
3794    response: Response<proto::LeaveChannelBuffer>,
3795    session: MessageContext,
3796) -> Result<()> {
3797    let db = session.db().await;
3798    let channel_id = ChannelId::from_proto(request.channel_id);
3799
3800    let left_buffer = db
3801        .leave_channel_buffer(channel_id, session.connection_id)
3802        .await?;
3803
3804    response.send(Ack {})?;
3805
3806    channel_buffer_updated(
3807        session.connection_id,
3808        left_buffer.connections,
3809        &proto::UpdateChannelBufferCollaborators {
3810            channel_id: channel_id.to_proto(),
3811            collaborators: left_buffer.collaborators,
3812        },
3813        &session.peer,
3814    );
3815
3816    Ok(())
3817}
3818
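    /// Broadcasts a channel buffer message to all collaborators except the sender.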
3819fn channel_buffer_updated<T: EnvelopedMessage>(
3820    sender_id: ConnectionId,
3821    collaborators: impl IntoIterator<Item = ConnectionId>,
3822    message: &T,
3823    peer: &Peer,
3824) {
3825    broadcast(Some(sender_id), collaborators, |peer_id| {
3826        peer.send(peer_id, message.clone())
3827    });
3828}
3829
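    /// Delivers each notification in the batch to every connection of its recipient, logging any
    /// failures to send.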
3830fn send_notifications(
3831    connection_pool: &ConnectionPool,
3832    peer: &Peer,
3833    notifications: db::NotificationBatch,
3834) {
3835    for (user_id, notification) in notifications {
3836        for connection_id in connection_pool.user_connection_ids(user_id) {
3837            if let Err(error) = peer.send(
3838                connection_id,
3839                proto::AddNotification {
3840                    notification: Some(notification.clone()),
3841                },
3842            ) {
3843                tracing::error!(
3844                    "failed to send notification to {:?} {}",
3845                    connection_id,
3846                    error
3847                );
3848            }
3849        }
3850    }
3851}
3852
3853/// Send a message to the channel
3854async fn send_channel_message(
3855    request: proto::SendChannelMessage,
3856    response: Response<proto::SendChannelMessage>,
3857    session: MessageContext,
3858) -> Result<()> {
3859    // Validate the message body.
3860    let body = request.body.trim().to_string();
3861    if body.len() > MAX_MESSAGE_LEN {
3862        return Err(anyhow!("message is too long"))?;
3863    }
3864    if body.is_empty() {
3865        return Err(anyhow!("message can't be blank"))?;
3866    }
3867
3868    // TODO: adjust mentions if body is trimmed
3869
3870    let timestamp = OffsetDateTime::now_utc();
3871    let nonce = request.nonce.context("nonce can't be blank")?;
3872
3873    let channel_id = ChannelId::from_proto(request.channel_id);
3874    let CreatedChannelMessage {
3875        message_id,
3876        participant_connection_ids,
3877        notifications,
3878    } = session
3879        .db()
3880        .await
3881        .create_channel_message(
3882            channel_id,
3883            session.user_id(),
3884            &body,
3885            &request.mentions,
3886            timestamp,
3887            nonce.clone().into(),
3888            request.reply_to_message_id.map(MessageId::from_proto),
3889        )
3890        .await?;
3891
3892    let message = proto::ChannelMessage {
3893        sender_id: session.user_id().to_proto(),
3894        id: message_id.to_proto(),
3895        body,
3896        mentions: request.mentions,
3897        timestamp: timestamp.unix_timestamp() as u64,
3898        nonce: Some(nonce),
3899        reply_to_message_id: request.reply_to_message_id,
3900        edited_at: None,
3901    };
3902    broadcast(
3903        Some(session.connection_id),
3904        participant_connection_ids.clone(),
3905        |connection| {
3906            session.peer.send(
3907                connection,
3908                proto::ChannelMessageSent {
3909                    channel_id: channel_id.to_proto(),
3910                    message: Some(message.clone()),
3911                },
3912            )
3913        },
3914    );
3915    response.send(proto::SendChannelMessageResponse {
3916        message: Some(message),
3917    })?;
3918
3919    let pool = &*session.connection_pool().await;
3920    let non_participants =
3921        pool.channel_connection_ids(channel_id)
3922            .filter_map(|(connection_id, _)| {
3923                if participant_connection_ids.contains(&connection_id) {
3924                    None
3925                } else {
3926                    Some(connection_id)
3927                }
3928            });
3929    broadcast(None, non_participants, |peer_id| {
3930        session.peer.send(
3931            peer_id,
3932            proto::UpdateChannels {
3933                latest_channel_message_ids: vec![proto::ChannelMessageId {
3934                    channel_id: channel_id.to_proto(),
3935                    message_id: message_id.to_proto(),
3936                }],
3937                ..Default::default()
3938            },
3939        )
3940    });
3941    send_notifications(pool, &session.peer, notifications);
3942
3943    Ok(())
3944}
3945
3946/// Delete a channel message
3947async fn remove_channel_message(
3948    request: proto::RemoveChannelMessage,
3949    response: Response<proto::RemoveChannelMessage>,
3950    session: MessageContext,
3951) -> Result<()> {
3952    let channel_id = ChannelId::from_proto(request.channel_id);
3953    let message_id = MessageId::from_proto(request.message_id);
3954    let (connection_ids, existing_notification_ids) = session
3955        .db()
3956        .await
3957        .remove_channel_message(channel_id, message_id, session.user_id())
3958        .await?;
3959
3960    broadcast(
3961        Some(session.connection_id),
3962        connection_ids,
3963        move |connection| {
3964            session.peer.send(connection, request.clone())?;
3965
3966            for notification_id in &existing_notification_ids {
3967                session.peer.send(
3968                    connection,
3969                    proto::DeleteNotification {
3970                        notification_id: (*notification_id).to_proto(),
3971                    },
3972                )?;
3973            }
3974
3975            Ok(())
3976        },
3977    );
3978    response.send(proto::Ack {})?;
3979    Ok(())
3980}
3981
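    /// Edit a previously sent channel message.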
3982async fn update_channel_message(
3983    request: proto::UpdateChannelMessage,
3984    response: Response<proto::UpdateChannelMessage>,
3985    session: MessageContext,
3986) -> Result<()> {
3987    let channel_id = ChannelId::from_proto(request.channel_id);
3988    let message_id = MessageId::from_proto(request.message_id);
3989    let updated_at = OffsetDateTime::now_utc();
3990    let UpdatedChannelMessage {
3991        message_id,
3992        participant_connection_ids,
3993        notifications,
3994        reply_to_message_id,
3995        timestamp,
3996        deleted_mention_notification_ids,
3997        updated_mention_notifications,
3998    } = session
3999        .db()
4000        .await
4001        .update_channel_message(
4002            channel_id,
4003            message_id,
4004            session.user_id(),
4005            request.body.as_str(),
4006            &request.mentions,
4007            updated_at,
4008        )
4009        .await?;
4010
4011    let nonce = request.nonce.clone().context("nonce can't be blank")?;
4012
4013    let message = proto::ChannelMessage {
4014        sender_id: session.user_id().to_proto(),
4015        id: message_id.to_proto(),
4016        body: request.body.clone(),
4017        mentions: request.mentions.clone(),
4018        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4019        nonce: Some(nonce),
4020        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4021        edited_at: Some(updated_at.unix_timestamp() as u64),
4022    };
4023
4024    response.send(proto::Ack {})?;
4025
4026    let pool = &*session.connection_pool().await;
4027    broadcast(
4028        Some(session.connection_id),
4029        participant_connection_ids,
4030        |connection| {
4031            session.peer.send(
4032                connection,
4033                proto::ChannelMessageUpdate {
4034                    channel_id: channel_id.to_proto(),
4035                    message: Some(message.clone()),
4036                },
4037            )?;
4038
4039            for notification_id in &deleted_mention_notification_ids {
4040                session.peer.send(
4041                    connection,
4042                    proto::DeleteNotification {
4043                        notification_id: (*notification_id).to_proto(),
4044                    },
4045                )?;
4046            }
4047
4048            for notification in &updated_mention_notifications {
4049                session.peer.send(
4050                    connection,
4051                    proto::UpdateNotification {
4052                        notification: Some(notification.clone()),
4053                    },
4054                )?;
4055            }
4056
4057            Ok(())
4058        },
4059    );
4060
4061    send_notifications(pool, &session.peer, notifications);
4062
4063    Ok(())
4064}
4065
4066/// Mark a channel message as read
4067async fn acknowledge_channel_message(
4068    request: proto::AckChannelMessage,
4069    session: MessageContext,
4070) -> Result<()> {
4071    let channel_id = ChannelId::from_proto(request.channel_id);
4072    let message_id = MessageId::from_proto(request.message_id);
4073    let notifications = session
4074        .db()
4075        .await
4076        .observe_channel_message(channel_id, session.user_id(), message_id)
4077        .await?;
4078    send_notifications(
4079        &*session.connection_pool().await,
4080        &session.peer,
4081        notifications,
4082    );
4083    Ok(())
4084}
4085
4086/// Mark a buffer version as synced
4087async fn acknowledge_buffer_version(
4088    request: proto::AckBufferOperation,
4089    session: MessageContext,
4090) -> Result<()> {
4091    let buffer_id = BufferId::from_proto(request.buffer_id);
4092    session
4093        .db()
4094        .await
4095        .observe_buffer_version(
4096            buffer_id,
4097            session.user_id(),
4098            request.epoch as i32,
4099            &request.version,
4100        )
4101        .await?;
4102    Ok(())
4103}
4104
4105/// Get a Supermaven API key for the user
4106async fn get_supermaven_api_key(
4107    _request: proto::GetSupermavenApiKey,
4108    response: Response<proto::GetSupermavenApiKey>,
4109    session: MessageContext,
4110) -> Result<()> {
4111    let user_id: String = session.user_id().to_string();
4112    if !session.is_staff() {
4113        return Err(anyhow!("supermaven not enabled for this account"))?;
4114    }
4115
4116    let email = session.email().context("user must have an email")?;
4117
4118    let supermaven_admin_api = session
4119        .supermaven_client
4120        .as_ref()
4121        .context("supermaven not configured")?;
4122
4123    let result = supermaven_admin_api
4124        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4125        .await?;
4126
4127    response.send(proto::GetSupermavenApiKeyResponse {
4128        api_key: result.api_key,
4129    })?;
4130
4131    Ok(())
4132}
4133
4134/// Start receiving chat updates for a channel
4135async fn join_channel_chat(
4136    request: proto::JoinChannelChat,
4137    response: Response<proto::JoinChannelChat>,
4138    session: MessageContext,
4139) -> Result<()> {
4140    let channel_id = ChannelId::from_proto(request.channel_id);
4141
4142    let db = session.db().await;
4143    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4144        .await?;
4145    let messages = db
4146        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4147        .await?;
4148    response.send(proto::JoinChannelChatResponse {
4149        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4150        messages,
4151    })?;
4152    Ok(())
4153}
4154
4155/// Stop receiving chat updates for a channel
4156async fn leave_channel_chat(
4157    request: proto::LeaveChannelChat,
4158    session: MessageContext,
4159) -> Result<()> {
4160    let channel_id = ChannelId::from_proto(request.channel_id);
4161    session
4162        .db()
4163        .await
4164        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4165        .await?;
4166    Ok(())
4167}
4168
4169/// Retrieve the chat history for a channel
4170async fn get_channel_messages(
4171    request: proto::GetChannelMessages,
4172    response: Response<proto::GetChannelMessages>,
4173    session: MessageContext,
4174) -> Result<()> {
4175    let channel_id = ChannelId::from_proto(request.channel_id);
4176    let messages = session
4177        .db()
4178        .await
4179        .get_channel_messages(
4180            channel_id,
4181            session.user_id(),
4182            MESSAGE_COUNT_PER_PAGE,
4183            Some(MessageId::from_proto(request.before_message_id)),
4184        )
4185        .await?;
4186    response.send(proto::GetChannelMessagesResponse {
4187        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4188        messages,
4189    })?;
4190    Ok(())
4191}
4192
4193/// Retrieve specific chat messages
4194async fn get_channel_messages_by_id(
4195    request: proto::GetChannelMessagesById,
4196    response: Response<proto::GetChannelMessagesById>,
4197    session: MessageContext,
4198) -> Result<()> {
4199    let message_ids = request
4200        .message_ids
4201        .iter()
4202        .map(|id| MessageId::from_proto(*id))
4203        .collect::<Vec<_>>();
4204    let messages = session
4205        .db()
4206        .await
4207        .get_channel_messages_by_id(session.user_id(), &message_ids)
4208        .await?;
4209    response.send(proto::GetChannelMessagesResponse {
4210        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4211        messages,
4212    })?;
4213    Ok(())
4214}
4215
4216/// Retrieve the current user's notifications
4217async fn get_notifications(
4218    request: proto::GetNotifications,
4219    response: Response<proto::GetNotifications>,
4220    session: MessageContext,
4221) -> Result<()> {
4222    let notifications = session
4223        .db()
4224        .await
4225        .get_notifications(
4226            session.user_id(),
4227            NOTIFICATION_COUNT_PER_PAGE,
4228            request.before_id.map(db::NotificationId::from_proto),
4229        )
4230        .await?;
4231    response.send(proto::GetNotificationsResponse {
4232        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4233        notifications,
4234    })?;
4235    Ok(())
4236}
4237
4238/// Mark a notification as read
4239async fn mark_notification_as_read(
4240    request: proto::MarkNotificationRead,
4241    response: Response<proto::MarkNotificationRead>,
4242    session: MessageContext,
4243) -> Result<()> {
4244    let database = &session.db().await;
4245    let notifications = database
4246        .mark_notification_as_read_by_id(
4247            session.user_id(),
4248            NotificationId::from_proto(request.notification_id),
4249        )
4250        .await?;
4251    send_notifications(
4252        &*session.connection_pool().await,
4253        &session.peer,
4254        notifications,
4255    );
4256    response.send(proto::Ack {})?;
4257    Ok(())
4258}
4259
4260/// Get the current user's information
4261async fn get_private_user_info(
4262    _request: proto::GetPrivateUserInfo,
4263    response: Response<proto::GetPrivateUserInfo>,
4264    session: MessageContext,
4265) -> Result<()> {
4266    let db = session.db().await;
4267
4268    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4269    let user = db
4270        .get_user_by_id(session.user_id())
4271        .await?
4272        .context("user not found")?;
4273    let flags = db.get_user_flags(session.user_id()).await?;
4274
4275    response.send(proto::GetPrivateUserInfoResponse {
4276        metrics_id,
4277        staff: user.admin,
4278        flags,
4279        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4280    })?;
4281    Ok(())
4282}
4283
4284/// Accept the terms of service (tos) on behalf of the current user
4285async fn accept_terms_of_service(
4286    _request: proto::AcceptTermsOfService,
4287    response: Response<proto::AcceptTermsOfService>,
4288    session: MessageContext,
4289) -> Result<()> {
4290    let db = session.db().await;
4291
4292    let accepted_tos_at = Utc::now();
4293    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4294        .await?;
4295
4296    response.send(proto::AcceptTermsOfServiceResponse {
4297        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4298    })?;
4299
4300    // When the user accepts the terms of service, we want to refresh their LLM
4301    // token to grant access.
4302    session
4303        .peer
4304        .send(session.connection_id, proto::RefreshLlmToken {})?;
4305
4306    Ok(())
4307}
4308
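    /// Issue an LLM API token for the user, creating a Stripe customer and a Zed Free
    /// subscription for them if they don't already have one.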
4309async fn get_llm_api_token(
4310    _request: proto::GetLlmToken,
4311    response: Response<proto::GetLlmToken>,
4312    session: MessageContext,
4313) -> Result<()> {
4314    let db = session.db().await;
4315
4316    let flags = db.get_user_flags(session.user_id()).await?;
4317
4318    let user_id = session.user_id();
4319    let user = db
4320        .get_user_by_id(user_id)
4321        .await?
4322        .with_context(|| format!("user {user_id} not found"))?;
4323
4324    if user.accepted_tos_at.is_none() {
4325        Err(anyhow!("terms of service not accepted"))?
4326    }
4327
4328    let stripe_client = session
4329        .app_state
4330        .stripe_client
4331        .as_ref()
4332        .context("failed to retrieve Stripe client")?;
4333
4334    let stripe_billing = session
4335        .app_state
4336        .stripe_billing
4337        .as_ref()
4338        .context("failed to retrieve Stripe billing object")?;
4339
4340    let billing_customer = if let Some(billing_customer) =
4341        db.get_billing_customer_by_user_id(user.id).await?
4342    {
4343        billing_customer
4344    } else {
4345        let customer_id = stripe_billing
4346            .find_or_create_customer_by_email(user.email_address.as_deref())
4347            .await?;
4348
4349        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
4350            .await?
4351            .context("billing customer not found")?
4352    };
4353
4354    let billing_subscription =
4355        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
4356            billing_subscription
4357        } else {
4358            let stripe_customer_id =
4359                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
4360
4361            let stripe_subscription = stripe_billing
4362                .subscribe_to_zed_free(stripe_customer_id)
4363                .await?;
4364
4365            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
4366                billing_customer_id: billing_customer.id,
4367                kind: Some(SubscriptionKind::ZedFree),
4368                stripe_subscription_id: stripe_subscription.id.to_string(),
4369                stripe_subscription_status: stripe_subscription.status.into(),
4370                stripe_cancellation_reason: None,
4371                stripe_current_period_start: Some(stripe_subscription.current_period_start),
4372                stripe_current_period_end: Some(stripe_subscription.current_period_end),
4373            })
4374            .await?
4375        };
4376
4377    let billing_preferences = db.get_billing_preferences(user.id).await?;
4378
4379    let token = LlmTokenClaims::create(
4380        &user,
4381        session.is_staff(),
4382        billing_customer,
4383        billing_preferences,
4384        &flags,
4385        billing_subscription,
4386        session.system_id.clone(),
4387        &session.app_state.config,
4388    )?;
4389    response.send(proto::GetLlmTokenResponse { token })?;
4390    Ok(())
4391}
4392
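    /// Converts a `tungstenite` WebSocket message into its `axum` equivalent.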
4393fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4394    let message = match message {
4395        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4396        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4397        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4398        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4399        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4400            code: frame.code.into(),
4401            reason: frame.reason.as_str().to_owned().into(),
4402        })),
4403        // We should never receive a frame while reading the message, according
4404        // to the `tungstenite` maintainers:
4405        //
4406        // > It cannot occur when you read messages from the WebSocket, but it
4407        // > can be used when you want to send the raw frames (e.g. you want to
4408        // > send the frames to the WebSocket without composing the full message first).
4409        // >
4410        // > — https://github.com/snapview/tungstenite-rs/issues/268
4411        TungsteniteMessage::Frame(_) => {
4412            bail!("received an unexpected frame while reading the message")
4413        }
4414    };
4415
4416    Ok(message)
4417}
4418
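    /// Converts an `axum` WebSocket message into its `tungstenite` equivalent.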
4419fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4420    match message {
4421        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4422        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4423        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4424        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4425        AxumMessage::Close(frame) => {
4426            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4427                code: frame.code.into(),
4428                reason: frame.reason.as_ref().into(),
4429            }))
4430        }
4431    }
4432}
4433
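    /// Updates the connection pool's channel subscriptions after a membership change and pushes
    /// the resulting channel updates to all of the user's connections.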
fn notify_membership_updated(
    connection_pool: &mut ConnectionPool,
    result: MembershipUpdated,
    user_id: UserId,
    peer: &Peer,
) {
    for membership in &result.new_channels.channel_memberships {
        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    for channel_id in &result.removed_channels {
        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
    }

    let user_channels_update = proto::UpdateUserChannels {
        channel_memberships: result
            .new_channels
            .channel_memberships
            .iter()
            .map(|cm| proto::ChannelMembership {
                channel_id: cm.channel_id.to_proto(),
                role: cm.role.into(),
            })
            .collect(),
        ..Default::default()
    };

    let mut update = build_channels_update(result.new_channels);
    update.delete_channels = result
        .removed_channels
        .into_iter()
        .map(|id| id.to_proto())
        .collect();
    update.remove_channel_invitations = vec![result.channel_id.to_proto()];

    for connection_id in connection_pool.user_connection_ids(user_id) {
        peer.send(connection_id, user_channels_update.clone())
            .trace_err();
        peer.send(connection_id, update.clone()).trace_err();
    }
}

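/// Builds the `UpdateUserChannels` message describing a user's channel
/// memberships along with the buffer versions and chat messages they have
/// already observed.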
fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
    proto::UpdateUserChannels {
        channel_memberships: channels
            .channel_memberships
            .iter()
            .map(|m| proto::ChannelMembership {
                channel_id: m.channel_id.to_proto(),
                role: m.role.into(),
            })
            .collect(),
        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
        observed_channel_message_id: channels.observed_channel_messages.clone(),
    }
}

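/// Builds a full `UpdateChannels` message from the channels visible to a
/// user: the channels themselves, the latest buffer versions and message ids,
/// the current participants per channel, and any pending invitations.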
fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
    let mut update = proto::UpdateChannels::default();

    for channel in channels.channels {
        update.channels.push(channel.to_proto());
    }

    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
    update.latest_channel_message_ids = channels.latest_channel_messages;

    for (channel_id, participants) in channels.channel_participants {
        update
            .channel_participants
            .push(proto::ChannelParticipants {
                channel_id: channel_id.to_proto(),
                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
            });
    }

    for channel in channels.invited_channels {
        update.channel_invitations.push(channel.to_proto());
    }

    update
}

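/// Builds the initial `UpdateContacts` message for a newly connected client,
/// splitting the user's contacts into accepted contacts (with online/busy
/// state taken from the connection pool), outgoing requests, and incoming
/// requests.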
fn build_initial_contacts_update(
    contacts: Vec<db::Contact>,
    pool: &ConnectionPool,
) -> proto::UpdateContacts {
    let mut update = proto::UpdateContacts::default();

    for contact in contacts {
        match contact {
            db::Contact::Accepted { user_id, busy } => {
                update.contacts.push(contact_for_user(user_id, busy, pool));
            }
            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
            db::Contact::Incoming { user_id } => {
                update
                    .incoming_requests
                    .push(proto::IncomingContactRequest {
                        requester_id: user_id.to_proto(),
                    })
            }
        }
    }

    update
}

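/// Builds the `proto::Contact` entry for a single user, looking up whether
/// they are currently online in the connection pool.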
fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
    proto::Contact {
        user_id: user_id.to_proto(),
        online: pool.is_user_online(user_id),
        busy,
    }
}

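/// Broadcasts the current state of a room to every participant with an
/// active connection.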
fn room_updated(room: &proto::Room, peer: &Peer) {
    broadcast(
        None,
        room.participants
            .iter()
            .filter_map(|participant| Some(participant.peer_id?.into())),
        |peer_id| {
            peer.send(
                peer_id,
                proto::RoomUpdated {
                    room: Some(room.clone()),
                },
            )
        },
    );
}

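/// Broadcasts the participant list of the room associated with a channel to
/// every connection subscribed to the channel's root, filtered to those whose
/// role is allowed to see the channel's visibility level.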
fn channel_updated(
    channel: &db::channel::Model,
    room: &proto::Room,
    peer: &Peer,
    pool: &ConnectionPool,
) {
    let participants = room
        .participants
        .iter()
        .map(|p| p.user_id)
        .collect::<Vec<_>>();

    broadcast(
        None,
        pool.channel_connection_ids(channel.root_id())
            .filter_map(|(channel_id, role)| {
                role.can_see_channel(channel.visibility)
                    .then_some(channel_id)
            }),
        |peer_id| {
            peer.send(
                peer_id,
                proto::UpdateChannels {
                    channel_participants: vec![proto::ChannelParticipants {
                        channel_id: channel.id.to_proto(),
                        participant_user_ids: participants.clone(),
                    }],
                    ..Default::default()
                },
            )
        },
    );
}

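/// Recomputes a user's contact entry (including their busy state) and pushes
/// the refreshed entry to every connection of each of their accepted
/// contacts.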
async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
    let db = session.db().await;

    let contacts = db.get_contacts(user_id).await?;
    let busy = db.is_user_busy(user_id).await?;

    let pool = session.connection_pool().await;
    let updated_contact = contact_for_user(user_id, busy, &pool);
    for contact in contacts {
        if let db::Contact::Accepted {
            user_id: contact_user_id,
            ..
        } = contact
        {
            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
                session
                    .peer
                    .send(
                        contact_conn_id,
                        proto::UpdateContacts {
                            contacts: vec![updated_contact.clone()],
                            remove_contacts: Default::default(),
                            incoming_requests: Default::default(),
                            remove_incoming_requests: Default::default(),
                            outgoing_requests: Default::default(),
                            remove_outgoing_requests: Default::default(),
                        },
                    )
                    .trace_err();
            }
        }
    }
    Ok(())
}

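/// Removes a connection from the room it is currently in, if any: notifies
/// the remaining participants and the owning channel, cancels pending
/// outgoing calls, refreshes the affected contacts, and removes the user from
/// LiveKit, deleting the LiveKit room when the database marks the room as
/// deleted.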
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}

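/// Removes a connection from every channel buffer it had open and notifies
/// the remaining collaborators on each of those buffers.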
async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
    let left_channel_buffers = session
        .db()
        .await
        .leave_channel_buffers(session.connection_id)
        .await?;

    for left_buffer in left_channel_buffers {
        channel_buffer_updated(
            session.connection_id,
            left_buffer.connections,
            &proto::UpdateChannelBufferCollaborators {
                channel_id: left_buffer.channel_id.to_proto(),
                collaborators: left_buffer.collaborators,
            },
            &session.peer,
        );
    }

    Ok(())
}

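/// Notifies the other connections on a project that this session has left it,
/// either unsharing the project entirely or just removing this peer as a
/// collaborator, depending on the project's `should_unshare` flag.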
fn project_left(project: &db::LeftProject, session: &Session) {
    for connection_id in &project.connection_ids {
        if project.should_unshare {
            session
                .peer
                .send(
                    *connection_id,
                    proto::UnshareProject {
                        project_id: project.id.to_proto(),
                    },
                )
                .trace_err();
        } else {
            session
                .peer
                .send(
                    *connection_id,
                    proto::RemoveProjectCollaborator {
                        project_id: project.id.to_proto(),
                        peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }
}

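/// Extension trait for logging errors that are not worth propagating.
///
/// `trace_err` turns a `Result` into an `Option`, emitting any error through
/// `tracing::error!` (annotated with the caller's location via
/// `#[track_caller]`) and returning `None` instead of failing. For example,
/// handlers in this module use it to log failed sends without aborting:
///
/// ```ignore
/// peer.send(connection_id, update.clone()).trace_err();
/// ```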
pub trait ResultExt {
    type Ok;

    fn trace_err(self) -> Option<Self::Ok>;
}

impl<T, E> ResultExt for Result<T, E>
where
    E: std::fmt::Debug,
{
    type Ok = T;

    #[track_caller]
    fn trace_err(self) -> Option<T> {
        match self {
            Ok(value) => Some(value),
            Err(error) => {
                tracing::error!("{:?}", error);
                None
            }
        }
    }
}