rpc.rs

   1mod connection_pool;
   2
   3use crate::api::billing::find_or_create_billing_customer;
   4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   5use crate::db::billing_subscription::SubscriptionKind;
   6use crate::llm::db::LlmDatabase;
   7use crate::llm::{
   8    AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
   9    MIN_ACCOUNT_AGE_FOR_LLM_USE,
  10};
  11use crate::stripe_client::StripeCustomerId;
  12use crate::{
  13    AppState, Error, Result, auth,
  14    db::{
  15        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
  16        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  17        NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
  18        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  19    },
  20    executor::Executor,
  21};
  22use anyhow::{Context as _, anyhow, bail};
  23use async_tungstenite::tungstenite::{
  24    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  25};
  26use axum::headers::UserAgent;
  27use axum::{
  28    Extension, Router, TypedHeader,
  29    body::Body,
  30    extract::{
  31        ConnectInfo, WebSocketUpgrade,
  32        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  33    },
  34    headers::{Header, HeaderName},
  35    http::StatusCode,
  36    middleware,
  37    response::IntoResponse,
  38    routing::get,
  39};
  40use chrono::Utc;
  41use collections::{HashMap, HashSet};
  42pub use connection_pool::{ConnectionPool, ZedVersion};
  43use core::fmt::{self, Debug, Formatter};
  44use futures::TryFutureExt as _;
  45use reqwest_client::ReqwestClient;
  46use rpc::proto::{MultiLspQuery, split_repository_update};
  47use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  48use tracing::Span;
  49
  50use futures::{
  51    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  52    stream::FuturesUnordered,
  53};
  54use prometheus::{IntGauge, register_int_gauge};
  55use rpc::{
  56    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  57    proto::{
  58        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  59        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  60    },
  61};
  62use semantic_version::SemanticVersion;
  63use serde::{Serialize, Serializer};
  64use std::{
  65    any::TypeId,
  66    future::Future,
  67    marker::PhantomData,
  68    mem,
  69    net::SocketAddr,
  70    ops::{Deref, DerefMut},
  71    rc::Rc,
  72    sync::{
  73        Arc, OnceLock,
  74        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  75    },
  76    time::{Duration, Instant},
  77};
  78use time::OffsetDateTime;
  79use tokio::sync::{Semaphore, watch};
  80use tower::ServiceBuilder;
  81use tracing::{
  82    Instrument,
  83    field::{self},
  84    info_span, instrument,
  85};
  86
  87pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  88
  89// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  90pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  91
  92const MESSAGE_COUNT_PER_PAGE: usize = 100;
  93const MAX_MESSAGE_LEN: usize = 1024;
  94const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  95const MAX_CONCURRENT_CONNECTIONS: usize = 512;
  96
  97static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
  98
  99const TOTAL_DURATION_MS: &str = "total_duration_ms";
 100const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
 101const QUEUE_DURATION_MS: &str = "queue_duration_ms";
 102const HOST_WAITING_MS: &str = "host_waiting_ms";
 103
 104type MessageHandler =
 105    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
 106
 107pub struct ConnectionGuard;
 108
 109impl ConnectionGuard {
 110    pub fn try_acquire() -> Result<Self, ()> {
 111        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
 112        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 113            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 114            tracing::error!(
 115                "too many concurrent connections: {}",
 116                current_connections + 1
 117            );
 118            return Err(());
 119        }
 120        Ok(ConnectionGuard)
 121    }
 122}
 123
 124impl Drop for ConnectionGuard {
 125    fn drop(&mut self) {
 126        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 127    }
 128}
 129
 130struct Response<R> {
 131    peer: Arc<Peer>,
 132    receipt: Receipt<R>,
 133    responded: Arc<AtomicBool>,
 134}
 135
 136impl<R: RequestMessage> Response<R> {
 137    fn send(self, payload: R::Response) -> Result<()> {
 138        self.responded.store(true, SeqCst);
 139        self.peer.respond(self.receipt, payload)?;
 140        Ok(())
 141    }
 142}
 143
 144#[derive(Clone, Debug)]
 145pub enum Principal {
 146    User(User),
 147    Impersonated { user: User, admin: User },
 148}
 149
 150impl Principal {
 151    fn user(&self) -> &User {
 152        match self {
 153            Principal::User(user) => user,
 154            Principal::Impersonated { user, .. } => user,
 155        }
 156    }
 157
 158    fn update_span(&self, span: &tracing::Span) {
 159        match &self {
 160            Principal::User(user) => {
 161                span.record("user_id", user.id.0);
 162                span.record("login", &user.github_login);
 163            }
 164            Principal::Impersonated { user, admin } => {
 165                span.record("user_id", user.id.0);
 166                span.record("login", &user.github_login);
 167                span.record("impersonator", &admin.github_login);
 168            }
 169        }
 170    }
 171}
 172
 173#[derive(Clone)]
 174struct MessageContext {
 175    session: Session,
 176    span: tracing::Span,
 177}
 178
 179impl Deref for MessageContext {
 180    type Target = Session;
 181
 182    fn deref(&self) -> &Self::Target {
 183        &self.session
 184    }
 185}
 186
 187impl MessageContext {
 188    pub fn forward_request<T: RequestMessage>(
 189        &self,
 190        receiver_id: ConnectionId,
 191        request: T,
 192    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 193        let request_start_time = Instant::now();
 194        let span = self.span.clone();
 195        tracing::info!("start forwarding request");
 196        self.peer
 197            .forward_request(self.connection_id, receiver_id, request)
 198            .inspect(move |_| {
 199                span.record(
 200                    HOST_WAITING_MS,
 201                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 202                );
 203            })
 204            .inspect_err(|_| tracing::error!("error forwarding request"))
 205            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 206    }
 207}
 208
 209#[derive(Clone)]
 210struct Session {
 211    principal: Principal,
 212    connection_id: ConnectionId,
 213    db: Arc<tokio::sync::Mutex<DbHandle>>,
 214    peer: Arc<Peer>,
 215    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 216    app_state: Arc<AppState>,
 217    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 218    /// The GeoIP country code for the user.
 219    #[allow(unused)]
 220    geoip_country_code: Option<String>,
 221    system_id: Option<String>,
 222    _executor: Executor,
 223}
 224
 225impl Session {
 226    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 227        #[cfg(test)]
 228        tokio::task::yield_now().await;
 229        let guard = self.db.lock().await;
 230        #[cfg(test)]
 231        tokio::task::yield_now().await;
 232        guard
 233    }
 234
 235    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 236        #[cfg(test)]
 237        tokio::task::yield_now().await;
 238        let guard = self.connection_pool.lock();
 239        ConnectionPoolGuard {
 240            guard,
 241            _not_send: PhantomData,
 242        }
 243    }
 244
 245    fn is_staff(&self) -> bool {
 246        match &self.principal {
 247            Principal::User(user) => user.admin,
 248            Principal::Impersonated { .. } => true,
 249        }
 250    }
 251
 252    fn user_id(&self) -> UserId {
 253        match &self.principal {
 254            Principal::User(user) => user.id,
 255            Principal::Impersonated { user, .. } => user.id,
 256        }
 257    }
 258
 259    pub fn email(&self) -> Option<String> {
 260        match &self.principal {
 261            Principal::User(user) => user.email_address.clone(),
 262            Principal::Impersonated { user, .. } => user.email_address.clone(),
 263        }
 264    }
 265}
 266
 267impl Debug for Session {
 268    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 269        let mut result = f.debug_struct("Session");
 270        match &self.principal {
 271            Principal::User(user) => {
 272                result.field("user", &user.github_login);
 273            }
 274            Principal::Impersonated { user, admin } => {
 275                result.field("user", &user.github_login);
 276                result.field("impersonator", &admin.github_login);
 277            }
 278        }
 279        result.field("connection_id", &self.connection_id).finish()
 280    }
 281}
 282
 283struct DbHandle(Arc<Database>);
 284
 285impl Deref for DbHandle {
 286    type Target = Database;
 287
 288    fn deref(&self) -> &Self::Target {
 289        self.0.as_ref()
 290    }
 291}
 292
/// The collaboration RPC server.
///
/// It owns the [`Peer`], the shared connection pool, and the table of message
/// handlers keyed by message type.
pub struct Server {
    /// This instance's id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers registered through `add_handler`, keyed by the message's `TypeId`.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown` and back to `false` by the test-only `reset`.
    teardown: watch::Sender<bool>,
}

/// Guard over the connection pool, as returned by `Session::connection_pool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. It therefore cannot
/// be held across an `.await` inside a future that must be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable view of the server's peer and connection pool.
///
/// It holds the pool lock, through the contained guard, for as long as the
/// snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialize the pool behind the guard rather than the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 313
 314pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 315where
 316    S: Serializer,
 317    T: Deref<Target = U>,
 318    U: Serialize,
 319{
 320    Serialize::serialize(value.deref(), serializer)
 321}
 322
 323impl Server {
 324    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 325        let mut server = Self {
 326            id: parking_lot::Mutex::new(id),
 327            peer: Peer::new(id.0 as u32),
 328            app_state: app_state.clone(),
 329            connection_pool: Default::default(),
 330            handlers: Default::default(),
 331            teardown: watch::channel(false).0,
 332        };
 333
 334        server
 335            .add_request_handler(ping)
 336            .add_request_handler(create_room)
 337            .add_request_handler(join_room)
 338            .add_request_handler(rejoin_room)
 339            .add_request_handler(leave_room)
 340            .add_request_handler(set_room_participant_role)
 341            .add_request_handler(call)
 342            .add_request_handler(cancel_call)
 343            .add_message_handler(decline_call)
 344            .add_request_handler(update_participant_location)
 345            .add_request_handler(share_project)
 346            .add_message_handler(unshare_project)
 347            .add_request_handler(join_project)
 348            .add_message_handler(leave_project)
 349            .add_request_handler(update_project)
 350            .add_request_handler(update_worktree)
 351            .add_request_handler(update_repository)
 352            .add_request_handler(remove_repository)
 353            .add_message_handler(start_language_server)
 354            .add_message_handler(update_language_server)
 355            .add_message_handler(update_diagnostic_summary)
 356            .add_message_handler(update_worktree_settings)
 357            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 358            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 359            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 360            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 361            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 362            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 363            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 364            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 365            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 366            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 367            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 368            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 369            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 370            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 371            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 372            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 373            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 374            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 375            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 376            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 377            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 378            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 379            .add_request_handler(
 380                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 381            )
 382            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 383            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 384            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 385            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 386            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
 387            .add_request_handler(
 388                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 389            )
 390            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 391            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 392            .add_request_handler(
 393                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 394            )
 395            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 396            .add_request_handler(
 397                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 398            )
 399            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 400            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 401            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 402            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 403            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 404            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 405            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 406            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 407            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 408            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 409            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 410            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 411            .add_request_handler(
 412                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 413            )
 414            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 415            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 416            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 417            .add_request_handler(multi_lsp_query)
 418            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 419            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 420            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 421            .add_message_handler(create_buffer_for_peer)
 422            .add_request_handler(update_buffer)
 423            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 424            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 425            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 426            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 427            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 428            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 429            .add_message_handler(
 430                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 431            )
 432            .add_request_handler(get_users)
 433            .add_request_handler(fuzzy_search_users)
 434            .add_request_handler(request_contact)
 435            .add_request_handler(remove_contact)
 436            .add_request_handler(respond_to_contact_request)
 437            .add_message_handler(subscribe_to_channels)
 438            .add_request_handler(create_channel)
 439            .add_request_handler(delete_channel)
 440            .add_request_handler(invite_channel_member)
 441            .add_request_handler(remove_channel_member)
 442            .add_request_handler(set_channel_member_role)
 443            .add_request_handler(set_channel_visibility)
 444            .add_request_handler(rename_channel)
 445            .add_request_handler(join_channel_buffer)
 446            .add_request_handler(leave_channel_buffer)
 447            .add_message_handler(update_channel_buffer)
 448            .add_request_handler(rejoin_channel_buffers)
 449            .add_request_handler(get_channel_members)
 450            .add_request_handler(respond_to_channel_invite)
 451            .add_request_handler(join_channel)
 452            .add_request_handler(join_channel_chat)
 453            .add_message_handler(leave_channel_chat)
 454            .add_request_handler(send_channel_message)
 455            .add_request_handler(remove_channel_message)
 456            .add_request_handler(update_channel_message)
 457            .add_request_handler(get_channel_messages)
 458            .add_request_handler(get_channel_messages_by_id)
 459            .add_request_handler(get_notifications)
 460            .add_request_handler(mark_notification_as_read)
 461            .add_request_handler(move_channel)
 462            .add_request_handler(reorder_channel)
 463            .add_request_handler(follow)
 464            .add_message_handler(unfollow)
 465            .add_message_handler(update_followers)
 466            .add_request_handler(get_private_user_info)
 467            .add_request_handler(get_llm_api_token)
 468            .add_request_handler(accept_terms_of_service)
 469            .add_message_handler(acknowledge_channel_message)
 470            .add_message_handler(acknowledge_buffer_version)
 471            .add_request_handler(get_supermaven_api_key)
 472            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 473            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 474            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 475            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 476            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 477            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 478            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 479            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 480            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 481            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 482            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 483            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 484            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 485            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 486            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 487            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 488            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 489            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 490            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 491            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 492            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 493            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 494            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 495            .add_message_handler(update_context);
 496
 497        Arc::new(server)
 498    }
 499
    /// Spawns a detached task that cleans up stale state after `CLEANUP_TIMEOUT`.
    ///
    /// The state is whatever the database reports as stale for this environment,
    /// relative to this server's id.
    ///
    /// In order, the task:
    /// 1. deletes stale channel chat participants;
    /// 2. for each stale channel buffer, clears stale collaborators and sends
    ///    `UpdateChannelBufferCollaborators` to the connections returned for it;
    /// 3. for each stale room, it:
    ///    - clears stale participants;
    ///    - broadcasts the room (and its channel, if any);
    ///    - sends `CallCanceled` to users whose calls were canceled;
    ///    - sends `UpdateContacts` to the accepted contacts of affected users;
    ///    - deletes the LiveKit room when no participants remain;
    /// 4. deletes stale channel chat participants again, clears old worktree
    ///    entries, and deletes stale server records.
    ///
    /// Each fallible step ends in `.trace_err()`, so a failure in one step does
    /// not stop the later ones. The method returns `Ok(())` as soon as the task
    /// has been spawned.
    pub async fn start(&self) -> Result<()> {
        // Snapshot the current id; the spawned task only uses this copy.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is created here and awaited inside the task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            // Send the refreshed collaborator list to each
                            // connection returned for this channel buffer.
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // If clearing the room fails, these stay empty/false:
                        // nobody is notified and the LiveKit room is kept.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and users whose calls were
                            // canceled get their contact entry refreshed below.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the `.await`s below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Only notify when both lookups succeeded; the pool lock is
                            // taken after the awaits and dropped at the end of this block.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room only once nobody is left in the room.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // NOTE(review): repeats the call made at the start of the task; presumably
                // intentional (e.g. to catch participants that went stale meanwhile) — confirm.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 670
 671    pub fn teardown(&self) {
 672        self.peer.teardown();
 673        self.connection_pool.lock().reset();
 674        let _ = self.teardown.send(true);
 675    }
 676
 677    #[cfg(test)]
 678    pub fn reset(&self, id: ServerId) {
 679        self.teardown();
 680        *self.id.lock() = id;
 681        self.peer.reset(id.0 as u32);
 682        let _ = self.teardown.send(false);
 683    }
 684
 685    #[cfg(test)]
 686    pub fn id(&self) -> ServerId {
 687        *self.id.lock()
 688    }
 689
 690    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 691    where
 692        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
 693        Fut: 'static + Send + Future<Output = Result<()>>,
 694        M: EnvelopedMessage,
 695    {
 696        let prev_handler = self.handlers.insert(
 697            TypeId::of::<M>(),
 698            Box::new(move |envelope, session, span| {
 699                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 700                let received_at = envelope.received_at;
 701                tracing::info!("message received");
 702                let start_time = Instant::now();
 703                let future = (handler)(
 704                    *envelope,
 705                    MessageContext {
 706                        session,
 707                        span: span.clone(),
 708                    },
 709                );
 710                async move {
 711                    let result = future.await;
 712                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 713                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 714                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 715                    span.record(TOTAL_DURATION_MS, total_duration_ms);
 716                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
 717                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
 718                    match result {
 719                        Err(error) => {
 720                            tracing::error!(?error, "error handling message")
 721                        }
 722                        Ok(()) => tracing::info!("finished handling message"),
 723                    }
 724                }
 725                .boxed()
 726            }),
 727        );
 728        if prev_handler.is_some() {
 729            panic!("registered a handler for the same message twice");
 730        }
 731        self
 732    }
 733
 734    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 735    where
 736        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 737        Fut: 'static + Send + Future<Output = Result<()>>,
 738        M: EnvelopedMessage,
 739    {
 740        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 741        self
 742    }
 743
 744    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 745    where
 746        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 747        Fut: Send + Future<Output = Result<()>>,
 748        M: RequestMessage,
 749    {
 750        let handler = Arc::new(handler);
 751        self.add_handler(move |envelope, session| {
 752            let receipt = envelope.receipt();
 753            let handler = handler.clone();
 754            async move {
 755                let peer = session.peer.clone();
 756                let responded = Arc::new(AtomicBool::default());
 757                let response = Response {
 758                    peer: peer.clone(),
 759                    responded: responded.clone(),
 760                    receipt,
 761                };
 762                match (handler)(envelope.payload, response, session).await {
 763                    Ok(()) => {
 764                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 765                            Ok(())
 766                        } else {
 767                            Err(anyhow!("handler did not send a response"))?
 768                        }
 769                    }
 770                    Err(error) => {
 771                        let proto_err = match &error {
 772                            Error::Internal(err) => err.to_proto(),
 773                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 774                        };
 775                        peer.respond_with_error(receipt, proto_err)?;
 776                        Err(error)
 777                    }
 778                }
 779            }
 780        })
 781    }
 782
    /// Builds the future that drives one client connection from handshake to sign-out.
    ///
    /// Nothing is spawned here; the returned future is instrumented with a
    /// `handle connection` span. When polled, it:
    /// - returns immediately if the teardown flag is already set;
    /// - registers `connection` with the peer and builds the per-connection [`Session`];
    /// - sends the initial client update (`send_initial_client_update`). On success it
    ///   drops `connection_guard`; any earlier return drops it as well.
    /// - dispatches incoming messages to the registered handlers until one of three things
    ///   happens: the I/O future finishes, the client closes the stream, or teardown is
    ///   signalled;
    /// - in the first two cases, runs `connection_lost` to clean up. On teardown it returns
    ///   without that cleanup.
    ///
    /// `send_connection_id`, when provided, is handed the newly assigned connection id during
    /// the initial client update.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields left `Empty` here are filled in below (or by `principal.update_span`).
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribed up front. `borrow()` below checks the current teardown flag, and
        // `changed()` in the message loop wakes on later sends.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            // This future runs under `.instrument(span)`, so the current span is the
            // `handle connection` span created above.
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // Outbound HTTP client (used for the Supermaven admin API below), identified by
            // this crate's version. This shadows the client-supplied `user_agent` parameter,
            // which was already consumed above.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only built when an admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The handshake is done; release the slot reserved before the WebSocket upgrade.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            //
            // Every message (foreground or background) holds one semaphore permit until its
            // handler finishes. Once all MAX_CONCURRENT_HANDLERS permits are taken, no further
            // messages are read from `incoming_rx`.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            // Number of permits currently checked out, i.e. handlers still in flight.
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // Rebuilt on every iteration. If another `select_biased!` arm wins, this future
                // is dropped, which releases any permit it had already acquired.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // Biased: arms are polled in the order written, so teardown wins over I/O
                // completion, which wins over progressing handlers and reading new messages.
                futures::select_biased! {
                    // Server teardown: exit without running `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    // Drive in-flight foreground handlers; completed results are discarded here.
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                multi_lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            // Enter the span only while the handler future is constructed, so
                            // the synchronous part of the handler call is attributed to it.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // The permit is held until the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background handlers run detached, outside the ordering of the
                                // foreground set.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            // The incoming stream ended: the client closed the connection.
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight.
            drop(foreground_message_handlers);
            // Permits still checked out here belong to handlers that are still running
            // (e.g. detached background ones).
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
 974
 975    async fn send_initial_client_update(
 976        &self,
 977        connection_id: ConnectionId,
 978        zed_version: ZedVersion,
 979        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 980        session: &Session,
 981    ) -> Result<()> {
 982        self.peer.send(
 983            connection_id,
 984            proto::Hello {
 985                peer_id: Some(connection_id.into()),
 986            },
 987        )?;
 988        tracing::info!("sent hello message");
 989        if let Some(send_connection_id) = send_connection_id.take() {
 990            let _ = send_connection_id.send(connection_id);
 991        }
 992
 993        match &session.principal {
 994            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 995                if !user.connected_once {
 996                    self.peer.send(connection_id, proto::ShowContacts {})?;
 997                    self.app_state
 998                        .db
 999                        .set_user_connected_once(user.id, true)
1000                        .await?;
1001                }
1002
1003                update_user_plan(session).await?;
1004
1005                let contacts = self.app_state.db.get_contacts(user.id).await?;
1006
1007                {
1008                    let mut pool = self.connection_pool.lock();
1009                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1010                    self.peer.send(
1011                        connection_id,
1012                        build_initial_contacts_update(contacts, &pool),
1013                    )?;
1014                }
1015
1016                if should_auto_subscribe_to_channels(zed_version) {
1017                    subscribe_user_to_channels(user.id, session).await?;
1018                }
1019
1020                if let Some(incoming_call) =
1021                    self.app_state.db.incoming_call_for_user(user.id).await?
1022                {
1023                    self.peer.send(connection_id, incoming_call)?;
1024                }
1025
1026                update_user_contacts(user.id, session).await?;
1027            }
1028        }
1029
1030        Ok(())
1031    }
1032
1033    pub async fn invite_code_redeemed(
1034        self: &Arc<Self>,
1035        inviter_id: UserId,
1036        invitee_id: UserId,
1037    ) -> Result<()> {
1038        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1039            if let Some(code) = &user.invite_code {
1040                let pool = self.connection_pool.lock();
1041                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1042                for connection_id in pool.user_connection_ids(inviter_id) {
1043                    self.peer.send(
1044                        connection_id,
1045                        proto::UpdateContacts {
1046                            contacts: vec![invitee_contact.clone()],
1047                            ..Default::default()
1048                        },
1049                    )?;
1050                    self.peer.send(
1051                        connection_id,
1052                        proto::UpdateInviteInfo {
1053                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1054                            count: user.invite_count as u32,
1055                        },
1056                    )?;
1057                }
1058            }
1059        }
1060        Ok(())
1061    }
1062
1063    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1064        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1065            if let Some(invite_code) = &user.invite_code {
1066                let pool = self.connection_pool.lock();
1067                for connection_id in pool.user_connection_ids(user_id) {
1068                    self.peer.send(
1069                        connection_id,
1070                        proto::UpdateInviteInfo {
1071                            url: format!(
1072                                "{}{}",
1073                                self.app_state.config.invite_link_prefix, invite_code
1074                            ),
1075                            count: user.invite_count as u32,
1076                        },
1077                    )?;
1078                }
1079            }
1080        }
1081        Ok(())
1082    }
1083
1084    pub async fn update_plan_for_user(
1085        self: &Arc<Self>,
1086        user_id: UserId,
1087        update_user_plan: proto::UpdateUserPlan,
1088    ) -> Result<()> {
1089        let pool = self.connection_pool.lock();
1090        for connection_id in pool.user_connection_ids(user_id) {
1091            self.peer
1092                .send(connection_id, update_user_plan.clone())
1093                .trace_err();
1094        }
1095
1096        Ok(())
1097    }
1098
1099    /// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan`
1100    /// message on the Collab server.
1101    ///
1102    /// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint.
1103    pub async fn update_plan_for_user_legacy(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1104        let user = self
1105            .app_state
1106            .db
1107            .get_user_by_id(user_id)
1108            .await?
1109            .context("user not found")?;
1110
1111        let update_user_plan = make_update_user_plan_message(
1112            &user,
1113            user.admin,
1114            &self.app_state.db,
1115            self.app_state.llm_db.clone(),
1116        )
1117        .await?;
1118
1119        self.update_plan_for_user(user_id, update_user_plan).await
1120    }
1121
1122    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
1123        let pool = self.connection_pool.lock();
1124        for connection_id in pool.user_connection_ids(user_id) {
1125            self.peer
1126                .send(connection_id, proto::RefreshLlmToken {})
1127                .trace_err();
1128        }
1129    }
1130
1131    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1132        ServerSnapshot {
1133            connection_pool: ConnectionPoolGuard {
1134                guard: self.connection_pool.lock(),
1135                _not_send: PhantomData,
1136            },
1137            peer: &self.peer,
1138        }
1139    }
1140}
1141
1142impl Deref for ConnectionPoolGuard<'_> {
1143    type Target = ConnectionPool;
1144
1145    fn deref(&self) -> &Self::Target {
1146        &self.guard
1147    }
1148}
1149
1150impl DerefMut for ConnectionPoolGuard<'_> {
1151    fn deref_mut(&mut self) -> &mut Self::Target {
1152        &mut self.guard
1153    }
1154}
1155
1156impl Drop for ConnectionPoolGuard<'_> {
1157    fn drop(&mut self) {
1158        #[cfg(test)]
1159        self.check_invariants();
1160    }
1161}
1162
1163fn broadcast<F>(
1164    sender_id: Option<ConnectionId>,
1165    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1166    mut f: F,
1167) where
1168    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1169{
1170    for receiver_id in receiver_ids {
1171        if Some(receiver_id) != sender_id {
1172            if let Err(error) = f(receiver_id) {
1173                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1174            }
1175        }
1176    }
1177}
1178
1179pub struct ProtocolVersion(u32);
1180
1181impl Header for ProtocolVersion {
1182    fn name() -> &'static HeaderName {
1183        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1184        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1185    }
1186
1187    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1188    where
1189        Self: Sized,
1190        I: Iterator<Item = &'i axum::http::HeaderValue>,
1191    {
1192        let version = values
1193            .next()
1194            .ok_or_else(axum::headers::Error::invalid)?
1195            .to_str()
1196            .map_err(|_| axum::headers::Error::invalid())?
1197            .parse()
1198            .map_err(|_| axum::headers::Error::invalid())?;
1199        Ok(Self(version))
1200    }
1201
1202    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1203        values.extend([self.0.to_string().parse().unwrap()]);
1204    }
1205}
1206
1207pub struct AppVersionHeader(SemanticVersion);
1208impl Header for AppVersionHeader {
1209    fn name() -> &'static HeaderName {
1210        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1211        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1212    }
1213
1214    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1215    where
1216        Self: Sized,
1217        I: Iterator<Item = &'i axum::http::HeaderValue>,
1218    {
1219        let version = values
1220            .next()
1221            .ok_or_else(axum::headers::Error::invalid)?
1222            .to_str()
1223            .map_err(|_| axum::headers::Error::invalid())?
1224            .parse()
1225            .map_err(|_| axum::headers::Error::invalid())?;
1226        Ok(Self(version))
1227    }
1228
1229    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1230        values.extend([self.0.to_string().parse().unwrap()]);
1231    }
1232}
1233
1234#[derive(Debug)]
1235pub struct ReleaseChannelHeader(String);
1236
1237impl Header for ReleaseChannelHeader {
1238    fn name() -> &'static HeaderName {
1239        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1240        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1241    }
1242
1243    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1244    where
1245        Self: Sized,
1246        I: Iterator<Item = &'i axum::http::HeaderValue>,
1247    {
1248        Ok(Self(
1249            values
1250                .next()
1251                .ok_or_else(axum::headers::Error::invalid)?
1252                .to_str()
1253                .map_err(|_| axum::headers::Error::invalid())?
1254                .to_owned(),
1255        ))
1256    }
1257
1258    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1259        values.extend([self.0.parse().unwrap()]);
1260    }
1261}
1262
1263pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1264    Router::new()
1265        .route("/rpc", get(handle_websocket_request))
1266        .layer(
1267            ServiceBuilder::new()
1268                .layer(Extension(server.app_state.clone()))
1269                .layer(middleware::from_fn(auth::validate_header)),
1270        )
1271        .route("/metrics", get(handle_metrics))
1272        .layer(Extension(server))
1273}
1274
1275pub async fn handle_websocket_request(
1276    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1277    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1278    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1279    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1280    Extension(server): Extension<Arc<Server>>,
1281    Extension(principal): Extension<Principal>,
1282    user_agent: Option<TypedHeader<UserAgent>>,
1283    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1284    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1285    ws: WebSocketUpgrade,
1286) -> axum::response::Response {
1287    if protocol_version != rpc::PROTOCOL_VERSION {
1288        return (
1289            StatusCode::UPGRADE_REQUIRED,
1290            "client must be upgraded".to_string(),
1291        )
1292            .into_response();
1293    }
1294
1295    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1296        return (
1297            StatusCode::UPGRADE_REQUIRED,
1298            "no version header found".to_string(),
1299        )
1300            .into_response();
1301    };
1302
1303    let release_channel = release_channel_header.map(|header| header.0.0);
1304
1305    if !version.can_collaborate() {
1306        return (
1307            StatusCode::UPGRADE_REQUIRED,
1308            "client must be upgraded".to_string(),
1309        )
1310            .into_response();
1311    }
1312
1313    let socket_address = socket_address.to_string();
1314
1315    // Acquire connection guard before WebSocket upgrade
1316    let connection_guard = match ConnectionGuard::try_acquire() {
1317        Ok(guard) => guard,
1318        Err(()) => {
1319            return (
1320                StatusCode::SERVICE_UNAVAILABLE,
1321                "Too many concurrent connections",
1322            )
1323                .into_response();
1324        }
1325    };
1326
1327    ws.on_upgrade(move |socket| {
1328        let socket = socket
1329            .map_ok(to_tungstenite_message)
1330            .err_into()
1331            .with(|message| async move { to_axum_message(message) });
1332        let connection = Connection::new(Box::pin(socket));
1333        async move {
1334            server
1335                .handle_connection(
1336                    connection,
1337                    socket_address,
1338                    principal,
1339                    version,
1340                    release_channel,
1341                    user_agent.map(|header| header.to_string()),
1342                    country_code_header.map(|header| header.to_string()),
1343                    system_id_header.map(|header| header.to_string()),
1344                    None,
1345                    Executor::Production,
1346                    Some(connection_guard),
1347                )
1348                .await;
1349        }
1350    })
1351}
1352
1353pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1354    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1355    let connections_metric = CONNECTIONS_METRIC
1356        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1357
1358    let connections = server
1359        .connection_pool
1360        .lock()
1361        .connections()
1362        .filter(|connection| !connection.admin)
1363        .count();
1364    connections_metric.set(connections as _);
1365
1366    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1367    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1368        register_int_gauge!(
1369            "shared_projects",
1370            "number of open projects with one or more guests"
1371        )
1372        .unwrap()
1373    });
1374
1375    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1376    shared_projects_metric.set(shared_projects as _);
1377
1378    let encoder = prometheus::TextEncoder::new();
1379    let metric_families = prometheus::gather();
1380    let encoded_metrics = encoder
1381        .encode_to_string(&metric_families)
1382        .map_err(|err| anyhow!("{err}"))?;
1383    Ok(encoded_metrics)
1384}
1385
1386#[instrument(err, skip(executor))]
1387async fn connection_lost(
1388    session: Session,
1389    mut teardown: watch::Receiver<bool>,
1390    executor: Executor,
1391) -> Result<()> {
1392    session.peer.disconnect(session.connection_id);
1393    session
1394        .connection_pool()
1395        .await
1396        .remove_connection(session.connection_id)?;
1397
1398    session
1399        .db()
1400        .await
1401        .connection_lost(session.connection_id)
1402        .await
1403        .trace_err();
1404
1405    futures::select_biased! {
1406        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1407
1408            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1409            leave_room_for_session(&session, session.connection_id).await.trace_err();
1410            leave_channel_buffers_for_session(&session)
1411                .await
1412                .trace_err();
1413
1414            if !session
1415                .connection_pool()
1416                .await
1417                .is_user_online(session.user_id())
1418            {
1419                let db = session.db().await;
1420                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1421                    room_updated(&room, &session.peer);
1422                }
1423            }
1424
1425            update_user_contacts(session.user_id(), &session).await?;
1426        },
1427        _ = teardown.changed().fuse() => {}
1428    }
1429
1430    Ok(())
1431}
1432
1433/// Acknowledges a ping from a client, used to keep the connection alive.
1434async fn ping(
1435    _: proto::Ping,
1436    response: Response<proto::Ping>,
1437    _session: MessageContext,
1438) -> Result<()> {
1439    response.send(proto::Ack {})?;
1440    Ok(())
1441}
1442
1443/// Creates a new room for calling (outside of channels)
1444async fn create_room(
1445    _request: proto::CreateRoom,
1446    response: Response<proto::CreateRoom>,
1447    session: MessageContext,
1448) -> Result<()> {
1449    let livekit_room = nanoid::nanoid!(30);
1450
1451    let live_kit_connection_info = util::maybe!(async {
1452        let live_kit = session.app_state.livekit_client.as_ref();
1453        let live_kit = live_kit?;
1454        let user_id = session.user_id().to_string();
1455
1456        let token = live_kit
1457            .room_token(&livekit_room, &user_id.to_string())
1458            .trace_err()?;
1459
1460        Some(proto::LiveKitConnectionInfo {
1461            server_url: live_kit.url().into(),
1462            token,
1463            can_publish: true,
1464        })
1465    })
1466    .await;
1467
1468    let room = session
1469        .db()
1470        .await
1471        .create_room(session.user_id(), session.connection_id, &livekit_room)
1472        .await?;
1473
1474    response.send(proto::CreateRoomResponse {
1475        room: Some(room.clone()),
1476        live_kit_connection_info,
1477    })?;
1478
1479    update_user_contacts(session.user_id(), &session).await?;
1480    Ok(())
1481}
1482
1483/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1484async fn join_room(
1485    request: proto::JoinRoom,
1486    response: Response<proto::JoinRoom>,
1487    session: MessageContext,
1488) -> Result<()> {
1489    let room_id = RoomId::from_proto(request.id);
1490
1491    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1492
1493    if let Some(channel_id) = channel_id {
1494        return join_channel_internal(channel_id, Box::new(response), session).await;
1495    }
1496
1497    let joined_room = {
1498        let room = session
1499            .db()
1500            .await
1501            .join_room(room_id, session.user_id(), session.connection_id)
1502            .await?;
1503        room_updated(&room.room, &session.peer);
1504        room.into_inner()
1505    };
1506
1507    for connection_id in session
1508        .connection_pool()
1509        .await
1510        .user_connection_ids(session.user_id())
1511    {
1512        session
1513            .peer
1514            .send(
1515                connection_id,
1516                proto::CallCanceled {
1517                    room_id: room_id.to_proto(),
1518                },
1519            )
1520            .trace_err();
1521    }
1522
1523    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1524    {
1525        live_kit
1526            .room_token(
1527                &joined_room.room.livekit_room,
1528                &session.user_id().to_string(),
1529            )
1530            .trace_err()
1531            .map(|token| proto::LiveKitConnectionInfo {
1532                server_url: live_kit.url().into(),
1533                token,
1534                can_publish: true,
1535            })
1536    } else {
1537        None
1538    };
1539
1540    response.send(proto::JoinRoomResponse {
1541        room: Some(joined_room.room),
1542        channel_id: None,
1543        live_kit_connection_info,
1544    })?;
1545
1546    update_user_contacts(session.user_id(), &session).await?;
1547    Ok(())
1548}
1549
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Steps, in order:
/// 1. Re-register this connection in the room via the database and reply with
///    the room plus its re-shared and rejoined projects.
/// 2. Broadcast the updated room.
/// 3. For each re-shared project, tell its collaborators that the peer id moved
///    from `old_connection_id` to this connection. Then forward the project's
///    worktree list to them as an `UpdateProject` sent on behalf of this
///    connection.
/// 4. Catch up the rejoined projects via `notify_rejoined_projects`.
/// 5. After the inner block (which consumes the rejoin guard with
///    `into_inner`), broadcast the channel update if the room has a channel,
///    then refresh the user's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    // Declared outside the block so they outlive the rejoin guard scoped inside it.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Reply to the rejoining client before notifying anyone else.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Swap the old peer id for the new one on every collaborator; send
            // failures are only logged.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the re-shared project's worktree list, sent as coming
            // from this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Consume the guard so only plain data escapes this block.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1641
1642fn notify_rejoined_projects(
1643    rejoined_projects: &mut Vec<RejoinedProject>,
1644    session: &Session,
1645) -> Result<()> {
1646    for project in rejoined_projects.iter() {
1647        for collaborator in &project.collaborators {
1648            session
1649                .peer
1650                .send(
1651                    collaborator.connection_id,
1652                    proto::UpdateProjectCollaborator {
1653                        project_id: project.id.to_proto(),
1654                        old_peer_id: Some(project.old_connection_id.into()),
1655                        new_peer_id: Some(session.connection_id.into()),
1656                    },
1657                )
1658                .trace_err();
1659        }
1660    }
1661
1662    for project in rejoined_projects {
1663        for worktree in mem::take(&mut project.worktrees) {
1664            // Stream this worktree's entries.
1665            let message = proto::UpdateWorktree {
1666                project_id: project.id.to_proto(),
1667                worktree_id: worktree.id,
1668                abs_path: worktree.abs_path.clone(),
1669                root_name: worktree.root_name,
1670                updated_entries: worktree.updated_entries,
1671                removed_entries: worktree.removed_entries,
1672                scan_id: worktree.scan_id,
1673                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1674                updated_repositories: worktree.updated_repositories,
1675                removed_repositories: worktree.removed_repositories,
1676            };
1677            for update in proto::split_worktree_update(message) {
1678                session.peer.send(session.connection_id, update)?;
1679            }
1680
1681            // Stream this worktree's diagnostics.
1682            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1683            if let Some(summary) = worktree_diagnostics.next() {
1684                let message = proto::UpdateDiagnosticSummary {
1685                    project_id: project.id.to_proto(),
1686                    worktree_id: worktree.id,
1687                    summary: Some(summary),
1688                    more_summaries: worktree_diagnostics.collect(),
1689                };
1690                session.peer.send(session.connection_id, message)?;
1691            }
1692
1693            for settings_file in worktree.settings_files {
1694                session.peer.send(
1695                    session.connection_id,
1696                    proto::UpdateWorktreeSettings {
1697                        project_id: project.id.to_proto(),
1698                        worktree_id: worktree.id,
1699                        path: settings_file.path,
1700                        content: Some(settings_file.content),
1701                        kind: Some(settings_file.kind.to_proto().into()),
1702                    },
1703                )?;
1704            }
1705        }
1706
1707        for repository in mem::take(&mut project.updated_repositories) {
1708            for update in split_repository_update(repository) {
1709                session.peer.send(session.connection_id, update)?;
1710            }
1711        }
1712
1713        for id in mem::take(&mut project.removed_repositories) {
1714            session.peer.send(
1715                session.connection_id,
1716                proto::RemoveRepository {
1717                    project_id: project.id.to_proto(),
1718                    id,
1719                },
1720            )?;
1721        }
1722    }
1723
1724    Ok(())
1725}
1726
1727/// leave room disconnects from the room.
1728async fn leave_room(
1729    _: proto::LeaveRoom,
1730    response: Response<proto::LeaveRoom>,
1731    session: MessageContext,
1732) -> Result<()> {
1733    leave_room_for_session(&session, session.connection_id).await?;
1734    response.send(proto::Ack {})?;
1735    Ok(())
1736}
1737
1738/// Updates the permissions of someone else in the room.
1739async fn set_room_participant_role(
1740    request: proto::SetRoomParticipantRole,
1741    response: Response<proto::SetRoomParticipantRole>,
1742    session: MessageContext,
1743) -> Result<()> {
1744    let user_id = UserId::from_proto(request.user_id);
1745    let role = ChannelRole::from(request.role());
1746
1747    let (livekit_room, can_publish) = {
1748        let room = session
1749            .db()
1750            .await
1751            .set_room_participant_role(
1752                session.user_id(),
1753                RoomId::from_proto(request.room_id),
1754                user_id,
1755                role,
1756            )
1757            .await?;
1758
1759        let livekit_room = room.livekit_room.clone();
1760        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1761        room_updated(&room, &session.peer);
1762        (livekit_room, can_publish)
1763    };
1764
1765    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1766        live_kit
1767            .update_participant(
1768                livekit_room.clone(),
1769                request.user_id.to_string(),
1770                livekit_api::proto::ParticipantPermission {
1771                    can_subscribe: true,
1772                    can_publish,
1773                    can_publish_data: can_publish,
1774                    hidden: false,
1775                    recorder: false,
1776                },
1777            )
1778            .await
1779            .trace_err();
1780    }
1781
1782    response.send(proto::Ack {})?;
1783    Ok(())
1784}
1785
1786/// Call someone else into the current room
1787async fn call(
1788    request: proto::Call,
1789    response: Response<proto::Call>,
1790    session: MessageContext,
1791) -> Result<()> {
1792    let room_id = RoomId::from_proto(request.room_id);
1793    let calling_user_id = session.user_id();
1794    let calling_connection_id = session.connection_id;
1795    let called_user_id = UserId::from_proto(request.called_user_id);
1796    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1797    if !session
1798        .db()
1799        .await
1800        .has_contact(calling_user_id, called_user_id)
1801        .await?
1802    {
1803        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1804    }
1805
1806    let incoming_call = {
1807        let (room, incoming_call) = &mut *session
1808            .db()
1809            .await
1810            .call(
1811                room_id,
1812                calling_user_id,
1813                calling_connection_id,
1814                called_user_id,
1815                initial_project_id,
1816            )
1817            .await?;
1818        room_updated(room, &session.peer);
1819        mem::take(incoming_call)
1820    };
1821    update_user_contacts(called_user_id, &session).await?;
1822
1823    let mut calls = session
1824        .connection_pool()
1825        .await
1826        .user_connection_ids(called_user_id)
1827        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1828        .collect::<FuturesUnordered<_>>();
1829
1830    while let Some(call_response) = calls.next().await {
1831        match call_response.as_ref() {
1832            Ok(_) => {
1833                response.send(proto::Ack {})?;
1834                return Ok(());
1835            }
1836            Err(_) => {
1837                call_response.trace_err();
1838            }
1839        }
1840    }
1841
1842    {
1843        let room = session
1844            .db()
1845            .await
1846            .call_failed(room_id, called_user_id)
1847            .await?;
1848        room_updated(&room, &session.peer);
1849    }
1850    update_user_contacts(called_user_id, &session).await?;
1851
1852    Err(anyhow!("failed to ring user"))?
1853}
1854
1855/// Cancel an outgoing call.
1856async fn cancel_call(
1857    request: proto::CancelCall,
1858    response: Response<proto::CancelCall>,
1859    session: MessageContext,
1860) -> Result<()> {
1861    let called_user_id = UserId::from_proto(request.called_user_id);
1862    let room_id = RoomId::from_proto(request.room_id);
1863    {
1864        let room = session
1865            .db()
1866            .await
1867            .cancel_call(room_id, session.connection_id, called_user_id)
1868            .await?;
1869        room_updated(&room, &session.peer);
1870    }
1871
1872    for connection_id in session
1873        .connection_pool()
1874        .await
1875        .user_connection_ids(called_user_id)
1876    {
1877        session
1878            .peer
1879            .send(
1880                connection_id,
1881                proto::CallCanceled {
1882                    room_id: room_id.to_proto(),
1883                },
1884            )
1885            .trace_err();
1886    }
1887    response.send(proto::Ack {})?;
1888
1889    update_user_contacts(called_user_id, &session).await?;
1890    Ok(())
1891}
1892
1893/// Decline an incoming call.
1894async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1895    let room_id = RoomId::from_proto(message.room_id);
1896    {
1897        let room = session
1898            .db()
1899            .await
1900            .decline_call(Some(room_id), session.user_id())
1901            .await?
1902            .context("declining call")?;
1903        room_updated(&room, &session.peer);
1904    }
1905
1906    for connection_id in session
1907        .connection_pool()
1908        .await
1909        .user_connection_ids(session.user_id())
1910    {
1911        session
1912            .peer
1913            .send(
1914                connection_id,
1915                proto::CallCanceled {
1916                    room_id: room_id.to_proto(),
1917                },
1918            )
1919            .trace_err();
1920    }
1921    update_user_contacts(session.user_id(), &session).await?;
1922    Ok(())
1923}
1924
1925/// Updates other participants in the room with your current location.
1926async fn update_participant_location(
1927    request: proto::UpdateParticipantLocation,
1928    response: Response<proto::UpdateParticipantLocation>,
1929    session: MessageContext,
1930) -> Result<()> {
1931    let room_id = RoomId::from_proto(request.room_id);
1932    let location = request.location.context("invalid location")?;
1933
1934    let db = session.db().await;
1935    let room = db
1936        .update_room_participant_location(room_id, session.connection_id, location)
1937        .await?;
1938
1939    room_updated(&room, &session.peer);
1940    response.send(proto::Ack {})?;
1941    Ok(())
1942}
1943
1944/// Share a project into the room.
1945async fn share_project(
1946    request: proto::ShareProject,
1947    response: Response<proto::ShareProject>,
1948    session: MessageContext,
1949) -> Result<()> {
1950    let (project_id, room) = &*session
1951        .db()
1952        .await
1953        .share_project(
1954            RoomId::from_proto(request.room_id),
1955            session.connection_id,
1956            &request.worktrees,
1957            request.is_ssh_project,
1958        )
1959        .await?;
1960    response.send(proto::ShareProjectResponse {
1961        project_id: project_id.to_proto(),
1962    })?;
1963    room_updated(room, &session.peer);
1964
1965    Ok(())
1966}
1967
1968/// Unshare a project from the room.
1969async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1970    let project_id = ProjectId::from_proto(message.project_id);
1971    unshare_project_internal(project_id, session.connection_id, &session).await
1972}
1973
1974async fn unshare_project_internal(
1975    project_id: ProjectId,
1976    connection_id: ConnectionId,
1977    session: &Session,
1978) -> Result<()> {
1979    let delete = {
1980        let room_guard = session
1981            .db()
1982            .await
1983            .unshare_project(project_id, connection_id)
1984            .await?;
1985
1986        let (delete, room, guest_connection_ids) = &*room_guard;
1987
1988        let message = proto::UnshareProject {
1989            project_id: project_id.to_proto(),
1990        };
1991
1992        broadcast(
1993            Some(connection_id),
1994            guest_connection_ids.iter().copied(),
1995            |conn_id| session.peer.send(conn_id, message.clone()),
1996        );
1997        if let Some(room) = room {
1998            room_updated(room, &session.peer);
1999        }
2000
2001        *delete
2002    };
2003
2004    if delete {
2005        let db = session.db().await;
2006        db.delete_project(project_id).await?;
2007    }
2008
2009    Ok(())
2010}
2011
2012/// Join someone elses shared project.
2013async fn join_project(
2014    request: proto::JoinProject,
2015    response: Response<proto::JoinProject>,
2016    session: MessageContext,
2017) -> Result<()> {
2018    let project_id = ProjectId::from_proto(request.project_id);
2019
2020    tracing::info!(%project_id, "join project");
2021
2022    let db = session.db().await;
2023    let (project, replica_id) = &mut *db
2024        .join_project(
2025            project_id,
2026            session.connection_id,
2027            session.user_id(),
2028            request.committer_name.clone(),
2029            request.committer_email.clone(),
2030        )
2031        .await?;
2032    drop(db);
2033    tracing::info!(%project_id, "join remote project");
2034    let collaborators = project
2035        .collaborators
2036        .iter()
2037        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2038        .map(|collaborator| collaborator.to_proto())
2039        .collect::<Vec<_>>();
2040    let project_id = project.id;
2041    let guest_user_id = session.user_id();
2042
2043    let worktrees = project
2044        .worktrees
2045        .iter()
2046        .map(|(id, worktree)| proto::WorktreeMetadata {
2047            id: *id,
2048            root_name: worktree.root_name.clone(),
2049            visible: worktree.visible,
2050            abs_path: worktree.abs_path.clone(),
2051        })
2052        .collect::<Vec<_>>();
2053
2054    let add_project_collaborator = proto::AddProjectCollaborator {
2055        project_id: project_id.to_proto(),
2056        collaborator: Some(proto::Collaborator {
2057            peer_id: Some(session.connection_id.into()),
2058            replica_id: replica_id.0 as u32,
2059            user_id: guest_user_id.to_proto(),
2060            is_host: false,
2061            committer_name: request.committer_name.clone(),
2062            committer_email: request.committer_email.clone(),
2063        }),
2064    };
2065
2066    for collaborator in &collaborators {
2067        session
2068            .peer
2069            .send(
2070                collaborator.peer_id.unwrap().into(),
2071                add_project_collaborator.clone(),
2072            )
2073            .trace_err();
2074    }
2075
2076    // First, we send the metadata associated with each worktree.
2077    let (language_servers, language_server_capabilities) = project
2078        .language_servers
2079        .clone()
2080        .into_iter()
2081        .map(|server| (server.server, server.capabilities))
2082        .unzip();
2083    response.send(proto::JoinProjectResponse {
2084        project_id: project.id.0 as u64,
2085        worktrees: worktrees.clone(),
2086        replica_id: replica_id.0 as u32,
2087        collaborators: collaborators.clone(),
2088        language_servers,
2089        language_server_capabilities,
2090        role: project.role.into(),
2091    })?;
2092
2093    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2094        // Stream this worktree's entries.
2095        let message = proto::UpdateWorktree {
2096            project_id: project_id.to_proto(),
2097            worktree_id,
2098            abs_path: worktree.abs_path.clone(),
2099            root_name: worktree.root_name,
2100            updated_entries: worktree.entries,
2101            removed_entries: Default::default(),
2102            scan_id: worktree.scan_id,
2103            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2104            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2105            removed_repositories: Default::default(),
2106        };
2107        for update in proto::split_worktree_update(message) {
2108            session.peer.send(session.connection_id, update.clone())?;
2109        }
2110
2111        // Stream this worktree's diagnostics.
2112        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
2113        if let Some(summary) = worktree_diagnostics.next() {
2114            let message = proto::UpdateDiagnosticSummary {
2115                project_id: project.id.to_proto(),
2116                worktree_id: worktree.id,
2117                summary: Some(summary),
2118                more_summaries: worktree_diagnostics.collect(),
2119            };
2120            session.peer.send(session.connection_id, message)?;
2121        }
2122
2123        for settings_file in worktree.settings_files {
2124            session.peer.send(
2125                session.connection_id,
2126                proto::UpdateWorktreeSettings {
2127                    project_id: project_id.to_proto(),
2128                    worktree_id: worktree.id,
2129                    path: settings_file.path,
2130                    content: Some(settings_file.content),
2131                    kind: Some(settings_file.kind.to_proto() as i32),
2132                },
2133            )?;
2134        }
2135    }
2136
2137    for repository in mem::take(&mut project.repositories) {
2138        for update in split_repository_update(repository) {
2139            session.peer.send(session.connection_id, update)?;
2140        }
2141    }
2142
2143    for language_server in &project.language_servers {
2144        session.peer.send(
2145            session.connection_id,
2146            proto::UpdateLanguageServer {
2147                project_id: project_id.to_proto(),
2148                server_name: Some(language_server.server.name.clone()),
2149                language_server_id: language_server.server.id,
2150                variant: Some(
2151                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2152                        proto::LspDiskBasedDiagnosticsUpdated {},
2153                    ),
2154                ),
2155            },
2156        )?;
2157    }
2158
2159    Ok(())
2160}
2161
2162/// Leave someone elses shared project.
2163async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2164    let sender_id = session.connection_id;
2165    let project_id = ProjectId::from_proto(request.project_id);
2166    let db = session.db().await;
2167
2168    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2169    tracing::info!(
2170        %project_id,
2171        "leave project"
2172    );
2173
2174    project_left(project, &session);
2175    if let Some(room) = room {
2176        room_updated(room, &session.peer);
2177    }
2178
2179    Ok(())
2180}
2181
2182/// Updates other participants with changes to the project
2183async fn update_project(
2184    request: proto::UpdateProject,
2185    response: Response<proto::UpdateProject>,
2186    session: MessageContext,
2187) -> Result<()> {
2188    let project_id = ProjectId::from_proto(request.project_id);
2189    let (room, guest_connection_ids) = &*session
2190        .db()
2191        .await
2192        .update_project(project_id, session.connection_id, &request.worktrees)
2193        .await?;
2194    broadcast(
2195        Some(session.connection_id),
2196        guest_connection_ids.iter().copied(),
2197        |connection_id| {
2198            session
2199                .peer
2200                .forward_send(session.connection_id, connection_id, request.clone())
2201        },
2202    );
2203    if let Some(room) = room {
2204        room_updated(room, &session.peer);
2205    }
2206    response.send(proto::Ack {})?;
2207
2208    Ok(())
2209}
2210
2211/// Updates other participants with changes to the worktree
2212async fn update_worktree(
2213    request: proto::UpdateWorktree,
2214    response: Response<proto::UpdateWorktree>,
2215    session: MessageContext,
2216) -> Result<()> {
2217    let guest_connection_ids = session
2218        .db()
2219        .await
2220        .update_worktree(&request, session.connection_id)
2221        .await?;
2222
2223    broadcast(
2224        Some(session.connection_id),
2225        guest_connection_ids.iter().copied(),
2226        |connection_id| {
2227            session
2228                .peer
2229                .forward_send(session.connection_id, connection_id, request.clone())
2230        },
2231    );
2232    response.send(proto::Ack {})?;
2233    Ok(())
2234}
2235
2236async fn update_repository(
2237    request: proto::UpdateRepository,
2238    response: Response<proto::UpdateRepository>,
2239    session: MessageContext,
2240) -> Result<()> {
2241    let guest_connection_ids = session
2242        .db()
2243        .await
2244        .update_repository(&request, session.connection_id)
2245        .await?;
2246
2247    broadcast(
2248        Some(session.connection_id),
2249        guest_connection_ids.iter().copied(),
2250        |connection_id| {
2251            session
2252                .peer
2253                .forward_send(session.connection_id, connection_id, request.clone())
2254        },
2255    );
2256    response.send(proto::Ack {})?;
2257    Ok(())
2258}
2259
2260async fn remove_repository(
2261    request: proto::RemoveRepository,
2262    response: Response<proto::RemoveRepository>,
2263    session: MessageContext,
2264) -> Result<()> {
2265    let guest_connection_ids = session
2266        .db()
2267        .await
2268        .remove_repository(&request, session.connection_id)
2269        .await?;
2270
2271    broadcast(
2272        Some(session.connection_id),
2273        guest_connection_ids.iter().copied(),
2274        |connection_id| {
2275            session
2276                .peer
2277                .forward_send(session.connection_id, connection_id, request.clone())
2278        },
2279    );
2280    response.send(proto::Ack {})?;
2281    Ok(())
2282}
2283
2284/// Updates other participants with changes to the diagnostics
2285async fn update_diagnostic_summary(
2286    message: proto::UpdateDiagnosticSummary,
2287    session: MessageContext,
2288) -> Result<()> {
2289    let guest_connection_ids = session
2290        .db()
2291        .await
2292        .update_diagnostic_summary(&message, session.connection_id)
2293        .await?;
2294
2295    broadcast(
2296        Some(session.connection_id),
2297        guest_connection_ids.iter().copied(),
2298        |connection_id| {
2299            session
2300                .peer
2301                .forward_send(session.connection_id, connection_id, message.clone())
2302        },
2303    );
2304
2305    Ok(())
2306}
2307
2308/// Updates other participants with changes to the worktree settings
2309async fn update_worktree_settings(
2310    message: proto::UpdateWorktreeSettings,
2311    session: MessageContext,
2312) -> Result<()> {
2313    let guest_connection_ids = session
2314        .db()
2315        .await
2316        .update_worktree_settings(&message, session.connection_id)
2317        .await?;
2318
2319    broadcast(
2320        Some(session.connection_id),
2321        guest_connection_ids.iter().copied(),
2322        |connection_id| {
2323            session
2324                .peer
2325                .forward_send(session.connection_id, connection_id, message.clone())
2326        },
2327    );
2328
2329    Ok(())
2330}
2331
2332/// Notify other participants that a language server has started.
2333async fn start_language_server(
2334    request: proto::StartLanguageServer,
2335    session: MessageContext,
2336) -> Result<()> {
2337    let guest_connection_ids = session
2338        .db()
2339        .await
2340        .start_language_server(&request, session.connection_id)
2341        .await?;
2342
2343    broadcast(
2344        Some(session.connection_id),
2345        guest_connection_ids.iter().copied(),
2346        |connection_id| {
2347            session
2348                .peer
2349                .forward_send(session.connection_id, connection_id, request.clone())
2350        },
2351    );
2352    Ok(())
2353}
2354
2355/// Notify other participants that a language server has changed.
2356async fn update_language_server(
2357    request: proto::UpdateLanguageServer,
2358    session: MessageContext,
2359) -> Result<()> {
2360    let project_id = ProjectId::from_proto(request.project_id);
2361    let db = session.db().await;
2362
2363    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2364    {
2365        if let Some(capabilities) = update.capabilities.clone() {
2366            db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2367                .await?;
2368        }
2369    }
2370
2371    let project_connection_ids = db
2372        .project_connection_ids(project_id, session.connection_id, true)
2373        .await?;
2374    broadcast(
2375        Some(session.connection_id),
2376        project_connection_ids.iter().copied(),
2377        |connection_id| {
2378            session
2379                .peer
2380                .forward_send(session.connection_id, connection_id, request.clone())
2381        },
2382    );
2383    Ok(())
2384}
2385
2386/// forward a project request to the host. These requests should be read only
2387/// as guests are allowed to send them.
2388async fn forward_read_only_project_request<T>(
2389    request: T,
2390    response: Response<T>,
2391    session: MessageContext,
2392) -> Result<()>
2393where
2394    T: EntityMessage + RequestMessage,
2395{
2396    let project_id = ProjectId::from_proto(request.remote_entity_id());
2397    let host_connection_id = session
2398        .db()
2399        .await
2400        .host_for_read_only_project_request(project_id, session.connection_id)
2401        .await?;
2402    let payload = session.forward_request(host_connection_id, request).await?;
2403    response.send(payload)?;
2404    Ok(())
2405}
2406
2407/// forward a project request to the host. These requests are disallowed
2408/// for guests.
2409async fn forward_mutating_project_request<T>(
2410    request: T,
2411    response: Response<T>,
2412    session: MessageContext,
2413) -> Result<()>
2414where
2415    T: EntityMessage + RequestMessage,
2416{
2417    let project_id = ProjectId::from_proto(request.remote_entity_id());
2418
2419    let host_connection_id = session
2420        .db()
2421        .await
2422        .host_for_mutating_project_request(project_id, session.connection_id)
2423        .await?;
2424    let payload = session.forward_request(host_connection_id, request).await?;
2425    response.send(payload)?;
2426    Ok(())
2427}
2428
2429async fn multi_lsp_query(
2430    request: MultiLspQuery,
2431    response: Response<MultiLspQuery>,
2432    session: MessageContext,
2433) -> Result<()> {
2434    tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2435    tracing::info!("multi_lsp_query message received");
2436    forward_mutating_project_request(request, response, session).await
2437}
2438
2439/// Notify other participants that a new buffer has been created
2440async fn create_buffer_for_peer(
2441    request: proto::CreateBufferForPeer,
2442    session: MessageContext,
2443) -> Result<()> {
2444    session
2445        .db()
2446        .await
2447        .check_user_is_project_host(
2448            ProjectId::from_proto(request.project_id),
2449            session.connection_id,
2450        )
2451        .await?;
2452    let peer_id = request.peer_id.context("invalid peer id")?;
2453    session
2454        .peer
2455        .forward_send(session.connection_id, peer_id.into(), request)?;
2456    Ok(())
2457}
2458
2459/// Notify other participants that a buffer has been updated. This is
2460/// allowed for guests as long as the update is limited to selections.
2461async fn update_buffer(
2462    request: proto::UpdateBuffer,
2463    response: Response<proto::UpdateBuffer>,
2464    session: MessageContext,
2465) -> Result<()> {
2466    let project_id = ProjectId::from_proto(request.project_id);
2467    let mut capability = Capability::ReadOnly;
2468
2469    for op in request.operations.iter() {
2470        match op.variant {
2471            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2472            Some(_) => capability = Capability::ReadWrite,
2473        }
2474    }
2475
2476    let host = {
2477        let guard = session
2478            .db()
2479            .await
2480            .connections_for_buffer_update(project_id, session.connection_id, capability)
2481            .await?;
2482
2483        let (host, guests) = &*guard;
2484
2485        broadcast(
2486            Some(session.connection_id),
2487            guests.clone(),
2488            |connection_id| {
2489                session
2490                    .peer
2491                    .forward_send(session.connection_id, connection_id, request.clone())
2492            },
2493        );
2494
2495        *host
2496    };
2497
2498    if host != session.connection_id {
2499        session.forward_request(host, request.clone()).await?;
2500    }
2501
2502    response.send(proto::Ack {})?;
2503    Ok(())
2504}
2505
/// Relays an `UpdateContext` operation to every other connection in the
/// project, the host included.
///
/// The capability passed to `connections_for_buffer_update` depends on the
/// operation:
/// - `ReadOnly` for a buffer operation whose inner operation has no variant or
///   only updates selections, or for an operation with no variant.
/// - `ReadWrite` for anything else, including a buffer operation with no inner
///   operation at all.
///
/// Presumably the database call rejects senders lacking that capability
/// (TODO confirm).
///
/// # Errors
/// Fails if the message carries no operation, or if the database call fails.
async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let operation = message.operation.as_ref().context("invalid operation")?;
    let capability = match operation.variant.as_ref() {
        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
            if let Some(buffer_op) = buffer_op.operation.as_ref() {
                // Same rule as `update_buffer`: selection-only (or empty) ops are read-only.
                match buffer_op.variant {
                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
                        Capability::ReadOnly
                    }
                    _ => Capability::ReadWrite,
                }
            } else {
                // A buffer operation with no payload is conservatively treated as a write.
                Capability::ReadWrite
            }
        }
        // Every non-buffer context operation requires write access.
        Some(_) => Capability::ReadWrite,
        None => Capability::ReadOnly,
    };

    let guard = session
        .db()
        .await
        .connections_for_buffer_update(project_id, session.connection_id, capability)
        .await?;

    let (host, guests) = &*guard;

    // Unlike `update_buffer`, the host is simply included in this one-way
    // broadcast rather than being sent a request. `guard` is still alive here,
    // so the broadcast happens while it is held.
    broadcast(
        Some(session.connection_id),
        guests.iter().chain([host]).copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2547
2548/// Notify other participants that a project has been updated.
2549async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2550    request: T,
2551    session: MessageContext,
2552) -> Result<()> {
2553    let project_id = ProjectId::from_proto(request.remote_entity_id());
2554    let project_connection_ids = session
2555        .db()
2556        .await
2557        .project_connection_ids(project_id, session.connection_id, false)
2558        .await?;
2559
2560    broadcast(
2561        Some(session.connection_id),
2562        project_connection_ids.iter().copied(),
2563        |connection_id| {
2564            session
2565                .peer
2566                .forward_send(session.connection_id, connection_id, request.clone())
2567        },
2568    );
2569    Ok(())
2570}
2571
/// Start following another user in a call.
///
/// Validates the leader/follower pair with `check_room_participants`, forwards
/// the `Follow` request to the leader, and relays the leader's reply back to the
/// follower. When the request names a project, the follow is then recorded in
/// the database and the updated room is passed to `room_updated`.
async fn follow(
    request: proto::Follow,
    response: Response<proto::Follow>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    // The leader arrives as a peer id; `.into()` turns it into a connection id.
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // The leader's reply reaches the follower before the follow is persisted, so
    // a later database failure still surfaces as an error from this handler.
    let response_payload = session.forward_request(leader_id, request).await?;
    response.send(response_payload)?;

    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .follow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2603
/// Stop following another user in a call.
///
/// Validates the leader/follower pair with `check_room_participants` and
/// forwards the `Unfollow` message to the leader as a one-way send (no reply is
/// awaited). When the request names a project, the follow is also removed in
/// the database and the updated room is passed to `room_updated`.
async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    // The leader arrives as a peer id; `.into()` turns it into a connection id.
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    session
        .peer
        .forward_send(session.connection_id, leader_id, request)?;

    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .unfollow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2632
/// Notify everyone following you of your current location.
///
/// Forwards the `UpdateFollowers` message to every other connection in the
/// project when `project_id` is set, or in the room otherwise. View updates are
/// not echoed back to the view's current leader.
async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    // NOTE(review): this locks `session.db` directly instead of calling
    // `session.db()` like the other handlers. Confirm that is intentional.
    // The guard lives until the end of the function, i.e. across the forwarding loop.
    let database = session.db.lock().await;

    let connection_ids = if let Some(project_id) = request.project_id {
        let project_id = ProjectId::from_proto(project_id);
        database
            .project_connection_ids(project_id, session.connection_id, true)
            .await?
    } else {
        database
            .room_connection_ids(room_id, session.connection_id)
            .await?
    };

    // For now, don't send view update messages back to that view's current leader.
    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
        _ => None,
    });

    // NOTE(review): the `?` below stops forwarding at the first failed send.
    // The other handlers use `broadcast`, which returns no error to the caller.
    for connection_id in connection_ids.iter().cloned() {
        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())?;
        }
    }
    Ok(())
}
2664
2665/// Get public data about users.
2666async fn get_users(
2667    request: proto::GetUsers,
2668    response: Response<proto::GetUsers>,
2669    session: MessageContext,
2670) -> Result<()> {
2671    let user_ids = request
2672        .user_ids
2673        .into_iter()
2674        .map(UserId::from_proto)
2675        .collect();
2676    let users = session
2677        .db()
2678        .await
2679        .get_users_by_ids(user_ids)
2680        .await?
2681        .into_iter()
2682        .map(|user| proto::User {
2683            id: user.id.to_proto(),
2684            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2685            github_login: user.github_login,
2686            name: user.name,
2687        })
2688        .collect();
2689    response.send(proto::UsersResponse { users })?;
2690    Ok(())
2691}
2692
2693/// Search for users (to invite) buy Github login
2694async fn fuzzy_search_users(
2695    request: proto::FuzzySearchUsers,
2696    response: Response<proto::FuzzySearchUsers>,
2697    session: MessageContext,
2698) -> Result<()> {
2699    let query = request.query;
2700    let users = match query.len() {
2701        0 => vec![],
2702        1 | 2 => session
2703            .db()
2704            .await
2705            .get_user_by_github_login(&query)
2706            .await?
2707            .into_iter()
2708            .collect(),
2709        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2710    };
2711    let users = users
2712        .into_iter()
2713        .filter(|user| user.id != session.user_id())
2714        .map(|user| proto::User {
2715            id: user.id.to_proto(),
2716            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2717            github_login: user.github_login,
2718            name: user.name,
2719        })
2720        .collect();
2721    response.send(proto::UsersResponse { users })?;
2722    Ok(())
2723}
2724
2725/// Send a contact request to another user.
2726async fn request_contact(
2727    request: proto::RequestContact,
2728    response: Response<proto::RequestContact>,
2729    session: MessageContext,
2730) -> Result<()> {
2731    let requester_id = session.user_id();
2732    let responder_id = UserId::from_proto(request.responder_id);
2733    if requester_id == responder_id {
2734        return Err(anyhow!("cannot add yourself as a contact"))?;
2735    }
2736
2737    let notifications = session
2738        .db()
2739        .await
2740        .send_contact_request(requester_id, responder_id)
2741        .await?;
2742
2743    // Update outgoing contact requests of requester
2744    let mut update = proto::UpdateContacts::default();
2745    update.outgoing_requests.push(responder_id.to_proto());
2746    for connection_id in session
2747        .connection_pool()
2748        .await
2749        .user_connection_ids(requester_id)
2750    {
2751        session.peer.send(connection_id, update.clone())?;
2752    }
2753
2754    // Update incoming contact requests of responder
2755    let mut update = proto::UpdateContacts::default();
2756    update
2757        .incoming_requests
2758        .push(proto::IncomingContactRequest {
2759            requester_id: requester_id.to_proto(),
2760        });
2761    let connection_pool = session.connection_pool().await;
2762    for connection_id in connection_pool.user_connection_ids(responder_id) {
2763        session.peer.send(connection_id, update.clone())?;
2764    }
2765
2766    send_notifications(&connection_pool, &session.peer, notifications);
2767
2768    response.send(proto::Ack {})?;
2769    Ok(())
2770}
2771
/// Accept or decline a contact request, or dismiss its notification.
///
/// A `Dismiss` response only calls `dismiss_contact_notification`, and no
/// contact updates are sent. Any other response counts as an accept if it
/// equals `Accept` and as a decline otherwise.
///
/// After the database is updated, both users' connections receive an
/// `UpdateContacts` that removes the pending request. On accept, it also adds
/// the other user as a contact, including their busy status. The notifications
/// returned by the database are then delivered.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    // The database guard stays held for the rest of the function.
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        // With `Dismiss` ruled out, anything other than `Accept` is a decline.
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // NOTE(review): busy status is only used on accept, but it is queried
        // for declines too.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        // Both the database guard and this pool guard are held while sending below.
        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2829
/// Remove a contact.
///
/// Here `requester_id` is the user doing the removal and `responder_id` is the
/// other user.
///
/// - If the two were accepted contacts, each is removed from the other's contacts.
/// - Otherwise, the pending request is removed from the remover's outgoing
///   requests and from the other user's incoming requests.
///
/// If the database reports a deleted notification, the other user's
/// connections are also told to delete it.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    // Each responder connection gets the contacts update followed by the
    // notification deletion (if any).
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2880
2881fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2882    version.0.minor() < 139
2883}
2884
2885async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2886    if is_staff {
2887        return Ok(proto::Plan::ZedPro);
2888    }
2889
2890    let subscription = db.get_active_billing_subscription(user_id).await?;
2891    let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2892
2893    let plan = if let Some(subscription_kind) = subscription_kind {
2894        match subscription_kind {
2895            SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2896            SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2897            SubscriptionKind::ZedFree => proto::Plan::Free,
2898        }
2899    } else {
2900        proto::Plan::Free
2901    };
2902
2903    Ok(plan)
2904}
2905
/// Builds the `UpdateUserPlan` message describing `user`'s plan and usage.
///
/// The message combines:
/// - the plan from `current_plan`;
/// - the billing customer's trial start and overdue-invoice flag;
/// - the user's overage preference (always enabled for staff);
/// - the current subscription period and the usage recorded for it.
///
/// When `llm_db` is `None`, no period or usage is looked up, and usage falls
/// back to the zeroed defaults from `make_default_subscription_usage`.
async fn make_update_user_plan_message(
    user: &User,
    is_staff: bool,
    db: &Arc<Database>,
    llm_db: Option<Arc<LlmDatabase>>,
) -> Result<proto::UpdateUserPlan> {
    let feature_flags = db.get_user_flags(user.id).await?;
    let plan = current_plan(db, user.id, is_staff).await?;
    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
        // NOTE(review): `current_plan` above already fetched the active
        // subscription; this is a second query for the same data.
        let subscription = db.get_active_billing_subscription(user.id).await?;

        let subscription_period =
            crate::db::billing_subscription::Model::current_period(subscription, is_staff);

        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
            llm_db
                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
                .await?
        } else {
            None
        };

        (subscription_period, usage)
    } else {
        (None, None)
    };

    // Non-Pro users without the bypass flag must meet the minimum account age.
    let bypass_account_age_check = feature_flags
        .iter()
        .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
    let account_too_young = !matches!(plan, proto::Plan::ZedPro)
        && !bypass_account_age_check
        && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;

    // NOTE(review): the `timestamp() as u64` casts below would wrap a pre-1970
    // timestamp; assumed never to occur.
    Ok(proto::UpdateUserPlan {
        plan: plan.into(),
        trial_started_at: billing_customer
            .as_ref()
            .and_then(|billing_customer| billing_customer.trial_started_at)
            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
        is_usage_based_billing_enabled: if is_staff {
            Some(true)
        } else {
            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
        },
        subscription_period: subscription_period.map(|(started_at, ended_at)| {
            proto::SubscriptionPeriod {
                started_at: started_at.timestamp() as u64,
                ended_at: ended_at.timestamp() as u64,
            }
        }),
        account_too_young: Some(account_too_young),
        has_overdue_invoices: billing_customer
            .map(|billing_customer| billing_customer.has_overdue_invoices),
        // Without a usage record for the period, report zero usage against the plan's limits.
        usage: Some(
            usage
                .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
                .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
        ),
    })
}
2970
2971fn model_requests_limit(
2972    plan: cloud_llm_client::Plan,
2973    feature_flags: &Vec<String>,
2974) -> cloud_llm_client::UsageLimit {
2975    match plan.model_requests_limit() {
2976        cloud_llm_client::UsageLimit::Limited(limit) => {
2977            let limit = if plan == cloud_llm_client::Plan::ZedProTrial
2978                && feature_flags
2979                    .iter()
2980                    .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2981            {
2982                1_000
2983            } else {
2984                limit
2985            };
2986
2987            cloud_llm_client::UsageLimit::Limited(limit)
2988        }
2989        cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited,
2990    }
2991}
2992
2993fn subscription_usage_to_proto(
2994    plan: proto::Plan,
2995    usage: crate::llm::db::subscription_usage::Model,
2996    feature_flags: &Vec<String>,
2997) -> proto::SubscriptionUsage {
2998    let plan = match plan {
2999        proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
3000        proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
3001        proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
3002    };
3003
3004    proto::SubscriptionUsage {
3005        model_requests_usage_amount: usage.model_requests as u32,
3006        model_requests_usage_limit: Some(proto::UsageLimit {
3007            variant: Some(match model_requests_limit(plan, feature_flags) {
3008                cloud_llm_client::UsageLimit::Limited(limit) => {
3009                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3010                        limit: limit as u32,
3011                    })
3012                }
3013                cloud_llm_client::UsageLimit::Unlimited => {
3014                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3015                }
3016            }),
3017        }),
3018        edit_predictions_usage_amount: usage.edit_predictions as u32,
3019        edit_predictions_usage_limit: Some(proto::UsageLimit {
3020            variant: Some(match plan.edit_predictions_limit() {
3021                cloud_llm_client::UsageLimit::Limited(limit) => {
3022                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3023                        limit: limit as u32,
3024                    })
3025                }
3026                cloud_llm_client::UsageLimit::Unlimited => {
3027                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3028                }
3029            }),
3030        }),
3031    }
3032}
3033
3034fn make_default_subscription_usage(
3035    plan: proto::Plan,
3036    feature_flags: &Vec<String>,
3037) -> proto::SubscriptionUsage {
3038    let plan = match plan {
3039        proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
3040        proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
3041        proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
3042    };
3043
3044    proto::SubscriptionUsage {
3045        model_requests_usage_amount: 0,
3046        model_requests_usage_limit: Some(proto::UsageLimit {
3047            variant: Some(match model_requests_limit(plan, feature_flags) {
3048                cloud_llm_client::UsageLimit::Limited(limit) => {
3049                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3050                        limit: limit as u32,
3051                    })
3052                }
3053                cloud_llm_client::UsageLimit::Unlimited => {
3054                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3055                }
3056            }),
3057        }),
3058        edit_predictions_usage_amount: 0,
3059        edit_predictions_usage_limit: Some(proto::UsageLimit {
3060            variant: Some(match plan.edit_predictions_limit() {
3061                cloud_llm_client::UsageLimit::Limited(limit) => {
3062                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3063                        limit: limit as u32,
3064                    })
3065                }
3066                cloud_llm_client::UsageLimit::Unlimited => {
3067                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3068                }
3069            }),
3070        }),
3071    }
3072}
3073
3074async fn update_user_plan(session: &Session) -> Result<()> {
3075    let db = session.db().await;
3076
3077    let update_user_plan = make_update_user_plan_message(
3078        session.principal.user(),
3079        session.is_staff(),
3080        &db.0,
3081        session.app_state.llm_db.clone(),
3082    )
3083    .await?;
3084
3085    session
3086        .peer
3087        .send(session.connection_id, update_user_plan)
3088        .trace_err();
3089
3090    Ok(())
3091}
3092
3093async fn subscribe_to_channels(
3094    _: proto::SubscribeToChannels,
3095    session: MessageContext,
3096) -> Result<()> {
3097    subscribe_user_to_channels(session.user_id(), &session).await?;
3098    Ok(())
3099}
3100
3101async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3102    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3103    let mut pool = session.connection_pool().await;
3104    for membership in &channels_for_user.channel_memberships {
3105        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3106    }
3107    session.peer.send(
3108        session.connection_id,
3109        build_update_user_channels(&channels_for_user),
3110    )?;
3111    session.peer.send(
3112        session.connection_id,
3113        build_channels_update(channels_for_user),
3114    )?;
3115    Ok(())
3116}
3117
3118/// Creates a new channel.
3119async fn create_channel(
3120    request: proto::CreateChannel,
3121    response: Response<proto::CreateChannel>,
3122    session: MessageContext,
3123) -> Result<()> {
3124    let db = session.db().await;
3125
3126    let parent_id = request.parent_id.map(ChannelId::from_proto);
3127    let (channel, membership) = db
3128        .create_channel(&request.name, parent_id, session.user_id())
3129        .await?;
3130
3131    let root_id = channel.root_id();
3132    let channel = Channel::from_model(channel);
3133
3134    response.send(proto::CreateChannelResponse {
3135        channel: Some(channel.to_proto()),
3136        parent_id: request.parent_id,
3137    })?;
3138
3139    let mut connection_pool = session.connection_pool().await;
3140    if let Some(membership) = membership {
3141        connection_pool.subscribe_to_channel(
3142            membership.user_id,
3143            membership.channel_id,
3144            membership.role,
3145        );
3146        let update = proto::UpdateUserChannels {
3147            channel_memberships: vec![proto::ChannelMembership {
3148                channel_id: membership.channel_id.to_proto(),
3149                role: membership.role.into(),
3150            }],
3151            ..Default::default()
3152        };
3153        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3154            session.peer.send(connection_id, update.clone())?;
3155        }
3156    }
3157
3158    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3159        if !role.can_see_channel(channel.visibility) {
3160            continue;
3161        }
3162
3163        let update = proto::UpdateChannels {
3164            channels: vec![channel.to_proto()],
3165            ..Default::default()
3166        };
3167        session.peer.send(connection_id, update.clone())?;
3168    }
3169
3170    Ok(())
3171}
3172
3173/// Delete a channel
3174async fn delete_channel(
3175    request: proto::DeleteChannel,
3176    response: Response<proto::DeleteChannel>,
3177    session: MessageContext,
3178) -> Result<()> {
3179    let db = session.db().await;
3180
3181    let channel_id = request.channel_id;
3182    let (root_channel, removed_channels) = db
3183        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3184        .await?;
3185    response.send(proto::Ack {})?;
3186
3187    // Notify members of removed channels
3188    let mut update = proto::UpdateChannels::default();
3189    update
3190        .delete_channels
3191        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3192
3193    let connection_pool = session.connection_pool().await;
3194    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3195        session.peer.send(connection_id, update.clone())?;
3196    }
3197
3198    Ok(())
3199}
3200
3201/// Invite someone to join a channel.
3202async fn invite_channel_member(
3203    request: proto::InviteChannelMember,
3204    response: Response<proto::InviteChannelMember>,
3205    session: MessageContext,
3206) -> Result<()> {
3207    let db = session.db().await;
3208    let channel_id = ChannelId::from_proto(request.channel_id);
3209    let invitee_id = UserId::from_proto(request.user_id);
3210    let InviteMemberResult {
3211        channel,
3212        notifications,
3213    } = db
3214        .invite_channel_member(
3215            channel_id,
3216            invitee_id,
3217            session.user_id(),
3218            request.role().into(),
3219        )
3220        .await?;
3221
3222    let update = proto::UpdateChannels {
3223        channel_invitations: vec![channel.to_proto()],
3224        ..Default::default()
3225    };
3226
3227    let connection_pool = session.connection_pool().await;
3228    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3229        session.peer.send(connection_id, update.clone())?;
3230    }
3231
3232    send_notifications(&connection_pool, &session.peer, notifications);
3233
3234    response.send(proto::Ack {})?;
3235    Ok(())
3236}
3237
3238/// remove someone from a channel
3239async fn remove_channel_member(
3240    request: proto::RemoveChannelMember,
3241    response: Response<proto::RemoveChannelMember>,
3242    session: MessageContext,
3243) -> Result<()> {
3244    let db = session.db().await;
3245    let channel_id = ChannelId::from_proto(request.channel_id);
3246    let member_id = UserId::from_proto(request.user_id);
3247
3248    let RemoveChannelMemberResult {
3249        membership_update,
3250        notification_id,
3251    } = db
3252        .remove_channel_member(channel_id, member_id, session.user_id())
3253        .await?;
3254
3255    let mut connection_pool = session.connection_pool().await;
3256    notify_membership_updated(
3257        &mut connection_pool,
3258        membership_update,
3259        member_id,
3260        &session.peer,
3261    );
3262    for connection_id in connection_pool.user_connection_ids(member_id) {
3263        if let Some(notification_id) = notification_id {
3264            session
3265                .peer
3266                .send(
3267                    connection_id,
3268                    proto::DeleteNotification {
3269                        notification_id: notification_id.to_proto(),
3270                    },
3271                )
3272                .trace_err();
3273        }
3274    }
3275
3276    response.send(proto::Ack {})?;
3277    Ok(())
3278}
3279
/// Toggle the channel between public and private.
///
/// The invariant is that public channels only descend from public channels, while
/// members-only channels may appear at any point in the hierarchy. Enforcing it is
/// presumably the job of `db.set_channel_visibility`, because this handler does not check it
/// (TODO confirm).
///
/// After the change is persisted, every user the connection pool associates with the
/// channel's root is re-evaluated:
/// - Users whose role can see the new visibility are subscribed and sent the channel.
/// - All other users are unsubscribed and told to delete it.
///
/// The request is acknowledged only after all updates have been sent.
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Snapshot the users up front: the loop body subscribes/unsubscribes through
    // `connection_pool`, which cannot happen while it is still being iterated.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            // The user's role can no longer see this channel: drop the subscription and
            // tell their clients to remove it.
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3326
3327/// Alter the role for a user in the channel.
3328async fn set_channel_member_role(
3329    request: proto::SetChannelMemberRole,
3330    response: Response<proto::SetChannelMemberRole>,
3331    session: MessageContext,
3332) -> Result<()> {
3333    let db = session.db().await;
3334    let channel_id = ChannelId::from_proto(request.channel_id);
3335    let member_id = UserId::from_proto(request.user_id);
3336    let result = db
3337        .set_channel_member_role(
3338            channel_id,
3339            session.user_id(),
3340            member_id,
3341            request.role().into(),
3342        )
3343        .await?;
3344
3345    match result {
3346        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3347            let mut connection_pool = session.connection_pool().await;
3348            notify_membership_updated(
3349                &mut connection_pool,
3350                membership_update,
3351                member_id,
3352                &session.peer,
3353            )
3354        }
3355        db::SetMemberRoleResult::InviteUpdated(channel) => {
3356            let update = proto::UpdateChannels {
3357                channel_invitations: vec![channel.to_proto()],
3358                ..Default::default()
3359            };
3360
3361            for connection_id in session
3362                .connection_pool()
3363                .await
3364                .user_connection_ids(member_id)
3365            {
3366                session.peer.send(connection_id, update.clone())?;
3367            }
3368        }
3369    }
3370
3371    response.send(proto::Ack {})?;
3372    Ok(())
3373}
3374
3375/// Change the name of a channel
3376async fn rename_channel(
3377    request: proto::RenameChannel,
3378    response: Response<proto::RenameChannel>,
3379    session: MessageContext,
3380) -> Result<()> {
3381    let db = session.db().await;
3382    let channel_id = ChannelId::from_proto(request.channel_id);
3383    let channel_model = db
3384        .rename_channel(channel_id, session.user_id(), &request.name)
3385        .await?;
3386    let root_id = channel_model.root_id();
3387    let channel = Channel::from_model(channel_model);
3388
3389    response.send(proto::RenameChannelResponse {
3390        channel: Some(channel.to_proto()),
3391    })?;
3392
3393    let connection_pool = session.connection_pool().await;
3394    let update = proto::UpdateChannels {
3395        channels: vec![channel.to_proto()],
3396        ..Default::default()
3397    };
3398    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3399        if role.can_see_channel(channel.visibility) {
3400            session.peer.send(connection_id, update.clone())?;
3401        }
3402    }
3403
3404    Ok(())
3405}
3406
3407/// Move a channel to a new parent.
3408async fn move_channel(
3409    request: proto::MoveChannel,
3410    response: Response<proto::MoveChannel>,
3411    session: MessageContext,
3412) -> Result<()> {
3413    let channel_id = ChannelId::from_proto(request.channel_id);
3414    let to = ChannelId::from_proto(request.to);
3415
3416    let (root_id, channels) = session
3417        .db()
3418        .await
3419        .move_channel(channel_id, to, session.user_id())
3420        .await?;
3421
3422    let connection_pool = session.connection_pool().await;
3423    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3424        let channels = channels
3425            .iter()
3426            .filter_map(|channel| {
3427                if role.can_see_channel(channel.visibility) {
3428                    Some(channel.to_proto())
3429                } else {
3430                    None
3431                }
3432            })
3433            .collect::<Vec<_>>();
3434        if channels.is_empty() {
3435            continue;
3436        }
3437
3438        let update = proto::UpdateChannels {
3439            channels,
3440            ..Default::default()
3441        };
3442
3443        session.peer.send(connection_id, update.clone())?;
3444    }
3445
3446    response.send(Ack {})?;
3447    Ok(())
3448}
3449
3450async fn reorder_channel(
3451    request: proto::ReorderChannel,
3452    response: Response<proto::ReorderChannel>,
3453    session: MessageContext,
3454) -> Result<()> {
3455    let channel_id = ChannelId::from_proto(request.channel_id);
3456    let direction = request.direction();
3457
3458    let updated_channels = session
3459        .db()
3460        .await
3461        .reorder_channel(channel_id, direction, session.user_id())
3462        .await?;
3463
3464    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3465        let connection_pool = session.connection_pool().await;
3466        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3467            let channels = updated_channels
3468                .iter()
3469                .filter_map(|channel| {
3470                    if role.can_see_channel(channel.visibility) {
3471                        Some(channel.to_proto())
3472                    } else {
3473                        None
3474                    }
3475                })
3476                .collect::<Vec<_>>();
3477
3478            if channels.is_empty() {
3479                continue;
3480            }
3481
3482            let update = proto::UpdateChannels {
3483                channels,
3484                ..Default::default()
3485            };
3486
3487            session.peer.send(connection_id, update.clone())?;
3488        }
3489    }
3490
3491    response.send(Ack {})?;
3492    Ok(())
3493}
3494
3495/// Get the list of channel members
3496async fn get_channel_members(
3497    request: proto::GetChannelMembers,
3498    response: Response<proto::GetChannelMembers>,
3499    session: MessageContext,
3500) -> Result<()> {
3501    let db = session.db().await;
3502    let channel_id = ChannelId::from_proto(request.channel_id);
3503    let limit = if request.limit == 0 {
3504        u16::MAX as u64
3505    } else {
3506        request.limit
3507    };
3508    let (members, users) = db
3509        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3510        .await?;
3511    response.send(proto::GetChannelMembersResponse { members, users })?;
3512    Ok(())
3513}
3514
3515/// Accept or decline a channel invitation.
3516async fn respond_to_channel_invite(
3517    request: proto::RespondToChannelInvite,
3518    response: Response<proto::RespondToChannelInvite>,
3519    session: MessageContext,
3520) -> Result<()> {
3521    let db = session.db().await;
3522    let channel_id = ChannelId::from_proto(request.channel_id);
3523    let RespondToChannelInvite {
3524        membership_update,
3525        notifications,
3526    } = db
3527        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3528        .await?;
3529
3530    let mut connection_pool = session.connection_pool().await;
3531    if let Some(membership_update) = membership_update {
3532        notify_membership_updated(
3533            &mut connection_pool,
3534            membership_update,
3535            session.user_id(),
3536            &session.peer,
3537        );
3538    } else {
3539        let update = proto::UpdateChannels {
3540            remove_channel_invitations: vec![channel_id.to_proto()],
3541            ..Default::default()
3542        };
3543
3544        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3545            session.peer.send(connection_id, update.clone())?;
3546        }
3547    };
3548
3549    send_notifications(&connection_pool, &session.peer, notifications);
3550
3551    response.send(proto::Ack {})?;
3552
3553    Ok(())
3554}
3555
3556/// Join the channels' room
3557async fn join_channel(
3558    request: proto::JoinChannel,
3559    response: Response<proto::JoinChannel>,
3560    session: MessageContext,
3561) -> Result<()> {
3562    let channel_id = ChannelId::from_proto(request.channel_id);
3563    join_channel_internal(channel_id, Box::new(response), session).await
3564}
3565
/// A response handle that can deliver a `proto::JoinRoomResponse`.
///
/// It is implemented for both `Response<proto::JoinChannel>` and `Response<proto::JoinRoom>`,
/// so `join_channel_internal` can reply through either one.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Replies to a `JoinChannel` request with the shared `JoinRoomResponse` payload.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegates to `Response`'s own `send`. The type is spelled out so it is clear this
        // is not a recursive call to the trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Replies to a `JoinRoom` request with the shared `JoinRoomResponse` payload.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegates to `Response`'s own `send`. The type is spelled out so it is clear this
        // is not a recursive call to the trait method.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3579
/// Shared implementation for joining the room that belongs to `channel_id`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that room first.
/// 2. Join the channel's room in the database. This also yields an optional membership
///    update and the user's role in the channel.
/// 3. Reply with the room, the channel id, and (when available) LiveKit connection info.
/// 4. Broadcast the membership update (if any) and the updated room.
/// 5. Broadcast the channel update and refresh the user's contacts.
///
/// `response` is generic over `JoinChannelInternalResponse`, so the same reply can be sent
/// through different request types' response handles.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    // The database handle and the first connection-pool guard are confined to this block;
    // the pool is acquired again for `channel_updated` below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database handle while leaving the stale room, then re-acquire it.
            // NOTE(review): presumably `leave_room_for_session` acquires the handle itself
            // and would otherwise wait on this one — confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a guest token and may not publish; every other role gets a room
        // token and may publish. The result is `None` in two cases:
        // - no LiveKit client is configured;
        // - token creation fails (the error is logged by `trace_err`).
        // In both cases the join still succeeds, just without LiveKit connection info.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        // Only broadcast a membership change when the database returned one.
        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // A missing channel on the joined room is reported as an error rather than a panic.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3673
3674/// Start editing the channel notes
3675async fn join_channel_buffer(
3676    request: proto::JoinChannelBuffer,
3677    response: Response<proto::JoinChannelBuffer>,
3678    session: MessageContext,
3679) -> Result<()> {
3680    let db = session.db().await;
3681    let channel_id = ChannelId::from_proto(request.channel_id);
3682
3683    let open_response = db
3684        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3685        .await?;
3686
3687    let collaborators = open_response.collaborators.clone();
3688    response.send(open_response)?;
3689
3690    let update = UpdateChannelBufferCollaborators {
3691        channel_id: channel_id.to_proto(),
3692        collaborators: collaborators.clone(),
3693    };
3694    channel_buffer_updated(
3695        session.connection_id,
3696        collaborators
3697            .iter()
3698            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3699        &update,
3700        &session.peer,
3701    );
3702
3703    Ok(())
3704}
3705
3706/// Edit the channel notes
3707async fn update_channel_buffer(
3708    request: proto::UpdateChannelBuffer,
3709    session: MessageContext,
3710) -> Result<()> {
3711    let db = session.db().await;
3712    let channel_id = ChannelId::from_proto(request.channel_id);
3713
3714    let (collaborators, epoch, version) = db
3715        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3716        .await?;
3717
3718    channel_buffer_updated(
3719        session.connection_id,
3720        collaborators.clone(),
3721        &proto::UpdateChannelBuffer {
3722            channel_id: channel_id.to_proto(),
3723            operations: request.operations,
3724        },
3725        &session.peer,
3726    );
3727
3728    let pool = &*session.connection_pool().await;
3729
3730    let non_collaborators =
3731        pool.channel_connection_ids(channel_id)
3732            .filter_map(|(connection_id, _)| {
3733                if collaborators.contains(&connection_id) {
3734                    None
3735                } else {
3736                    Some(connection_id)
3737                }
3738            });
3739
3740    broadcast(None, non_collaborators, |peer_id| {
3741        session.peer.send(
3742            peer_id,
3743            proto::UpdateChannels {
3744                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3745                    channel_id: channel_id.to_proto(),
3746                    epoch: epoch as u64,
3747                    version: version.clone(),
3748                }],
3749                ..Default::default()
3750            },
3751        )
3752    });
3753
3754    Ok(())
3755}
3756
3757/// Rejoin the channel notes after a connection blip
3758async fn rejoin_channel_buffers(
3759    request: proto::RejoinChannelBuffers,
3760    response: Response<proto::RejoinChannelBuffers>,
3761    session: MessageContext,
3762) -> Result<()> {
3763    let db = session.db().await;
3764    let buffers = db
3765        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3766        .await?;
3767
3768    for rejoined_buffer in &buffers {
3769        let collaborators_to_notify = rejoined_buffer
3770            .buffer
3771            .collaborators
3772            .iter()
3773            .filter_map(|c| Some(c.peer_id?.into()));
3774        channel_buffer_updated(
3775            session.connection_id,
3776            collaborators_to_notify,
3777            &proto::UpdateChannelBufferCollaborators {
3778                channel_id: rejoined_buffer.buffer.channel_id,
3779                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3780            },
3781            &session.peer,
3782        );
3783    }
3784
3785    response.send(proto::RejoinChannelBuffersResponse {
3786        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3787    })?;
3788
3789    Ok(())
3790}
3791
3792/// Stop editing the channel notes
3793async fn leave_channel_buffer(
3794    request: proto::LeaveChannelBuffer,
3795    response: Response<proto::LeaveChannelBuffer>,
3796    session: MessageContext,
3797) -> Result<()> {
3798    let db = session.db().await;
3799    let channel_id = ChannelId::from_proto(request.channel_id);
3800
3801    let left_buffer = db
3802        .leave_channel_buffer(channel_id, session.connection_id)
3803        .await?;
3804
3805    response.send(Ack {})?;
3806
3807    channel_buffer_updated(
3808        session.connection_id,
3809        left_buffer.connections,
3810        &proto::UpdateChannelBufferCollaborators {
3811            channel_id: channel_id.to_proto(),
3812            collaborators: left_buffer.collaborators,
3813        },
3814        &session.peer,
3815    );
3816
3817    Ok(())
3818}
3819
3820fn channel_buffer_updated<T: EnvelopedMessage>(
3821    sender_id: ConnectionId,
3822    collaborators: impl IntoIterator<Item = ConnectionId>,
3823    message: &T,
3824    peer: &Peer,
3825) {
3826    broadcast(Some(sender_id), collaborators, |peer_id| {
3827        peer.send(peer_id, message.clone())
3828    });
3829}
3830
3831fn send_notifications(
3832    connection_pool: &ConnectionPool,
3833    peer: &Peer,
3834    notifications: db::NotificationBatch,
3835) {
3836    for (user_id, notification) in notifications {
3837        for connection_id in connection_pool.user_connection_ids(user_id) {
3838            if let Err(error) = peer.send(
3839                connection_id,
3840                proto::AddNotification {
3841                    notification: Some(notification.clone()),
3842                },
3843            ) {
3844                tracing::error!(
3845                    "failed to send notification to {:?} {}",
3846                    connection_id,
3847                    error
3848                );
3849            }
3850        }
3851    }
3852}
3853
/// Send a message to the channel.
///
/// Validation happens before anything is stored:
/// - The body is trimmed of surrounding whitespace.
/// - It must then be at most `MAX_MESSAGE_LEN` bytes and non-empty.
/// - A nonce is required.
///
/// The message is then stored as the requesting user, and:
/// - It is forwarded to the chat participants the database returned. This connection is
///   passed to `broadcast` as the sender.
/// - It is returned in the response.
/// - Every other connection attached to the channel is told the channel's latest message id.
/// - Notifications returned by the database are delivered last.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: MessageContext,
) -> Result<()> {
    // Validate the message body.
    // `String::len` counts UTF-8 bytes, so the limit is in bytes, not characters.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): `request.mentions` is stored and echoed back unchanged. If mention
    // ranges are offsets into the body, trimming leading whitespace above shifts them out of
    // alignment — confirm the mention format before relying on these offsets.

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request.nonce.context("nonce can't be blank")?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    // Participants receive the full message.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Everyone else attached to the channel only learns the new latest message id.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3946
3947/// Delete a channel message
3948async fn remove_channel_message(
3949    request: proto::RemoveChannelMessage,
3950    response: Response<proto::RemoveChannelMessage>,
3951    session: MessageContext,
3952) -> Result<()> {
3953    let channel_id = ChannelId::from_proto(request.channel_id);
3954    let message_id = MessageId::from_proto(request.message_id);
3955    let (connection_ids, existing_notification_ids) = session
3956        .db()
3957        .await
3958        .remove_channel_message(channel_id, message_id, session.user_id())
3959        .await?;
3960
3961    broadcast(
3962        Some(session.connection_id),
3963        connection_ids,
3964        move |connection| {
3965            session.peer.send(connection, request.clone())?;
3966
3967            for notification_id in &existing_notification_ids {
3968                session.peer.send(
3969                    connection,
3970                    proto::DeleteNotification {
3971                        notification_id: (*notification_id).to_proto(),
3972                    },
3973                )?;
3974            }
3975
3976            Ok(())
3977        },
3978    );
3979    response.send(proto::Ack {})?;
3980    Ok(())
3981}
3982
3983async fn update_channel_message(
3984    request: proto::UpdateChannelMessage,
3985    response: Response<proto::UpdateChannelMessage>,
3986    session: MessageContext,
3987) -> Result<()> {
3988    let channel_id = ChannelId::from_proto(request.channel_id);
3989    let message_id = MessageId::from_proto(request.message_id);
3990    let updated_at = OffsetDateTime::now_utc();
3991    let UpdatedChannelMessage {
3992        message_id,
3993        participant_connection_ids,
3994        notifications,
3995        reply_to_message_id,
3996        timestamp,
3997        deleted_mention_notification_ids,
3998        updated_mention_notifications,
3999    } = session
4000        .db()
4001        .await
4002        .update_channel_message(
4003            channel_id,
4004            message_id,
4005            session.user_id(),
4006            request.body.as_str(),
4007            &request.mentions,
4008            updated_at,
4009        )
4010        .await?;
4011
4012    let nonce = request.nonce.clone().context("nonce can't be blank")?;
4013
4014    let message = proto::ChannelMessage {
4015        sender_id: session.user_id().to_proto(),
4016        id: message_id.to_proto(),
4017        body: request.body.clone(),
4018        mentions: request.mentions.clone(),
4019        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4020        nonce: Some(nonce),
4021        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4022        edited_at: Some(updated_at.unix_timestamp() as u64),
4023    };
4024
4025    response.send(proto::Ack {})?;
4026
4027    let pool = &*session.connection_pool().await;
4028    broadcast(
4029        Some(session.connection_id),
4030        participant_connection_ids,
4031        |connection| {
4032            session.peer.send(
4033                connection,
4034                proto::ChannelMessageUpdate {
4035                    channel_id: channel_id.to_proto(),
4036                    message: Some(message.clone()),
4037                },
4038            )?;
4039
4040            for notification_id in &deleted_mention_notification_ids {
4041                session.peer.send(
4042                    connection,
4043                    proto::DeleteNotification {
4044                        notification_id: (*notification_id).to_proto(),
4045                    },
4046                )?;
4047            }
4048
4049            for notification in &updated_mention_notifications {
4050                session.peer.send(
4051                    connection,
4052                    proto::UpdateNotification {
4053                        notification: Some(notification.clone()),
4054                    },
4055                )?;
4056            }
4057
4058            Ok(())
4059        },
4060    );
4061
4062    send_notifications(pool, &session.peer, notifications);
4063
4064    Ok(())
4065}
4066
4067/// Mark a channel message as read
4068async fn acknowledge_channel_message(
4069    request: proto::AckChannelMessage,
4070    session: MessageContext,
4071) -> Result<()> {
4072    let channel_id = ChannelId::from_proto(request.channel_id);
4073    let message_id = MessageId::from_proto(request.message_id);
4074    let notifications = session
4075        .db()
4076        .await
4077        .observe_channel_message(channel_id, session.user_id(), message_id)
4078        .await?;
4079    send_notifications(
4080        &*session.connection_pool().await,
4081        &session.peer,
4082        notifications,
4083    );
4084    Ok(())
4085}
4086
4087/// Mark a buffer version as synced
4088async fn acknowledge_buffer_version(
4089    request: proto::AckBufferOperation,
4090    session: MessageContext,
4091) -> Result<()> {
4092    let buffer_id = BufferId::from_proto(request.buffer_id);
4093    session
4094        .db()
4095        .await
4096        .observe_buffer_version(
4097            buffer_id,
4098            session.user_id(),
4099            request.epoch as i32,
4100            &request.version,
4101        )
4102        .await?;
4103    Ok(())
4104}
4105
4106/// Get a Supermaven API key for the user
4107async fn get_supermaven_api_key(
4108    _request: proto::GetSupermavenApiKey,
4109    response: Response<proto::GetSupermavenApiKey>,
4110    session: MessageContext,
4111) -> Result<()> {
4112    let user_id: String = session.user_id().to_string();
4113    if !session.is_staff() {
4114        return Err(anyhow!("supermaven not enabled for this account"))?;
4115    }
4116
4117    let email = session.email().context("user must have an email")?;
4118
4119    let supermaven_admin_api = session
4120        .supermaven_client
4121        .as_ref()
4122        .context("supermaven not configured")?;
4123
4124    let result = supermaven_admin_api
4125        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4126        .await?;
4127
4128    response.send(proto::GetSupermavenApiKeyResponse {
4129        api_key: result.api_key,
4130    })?;
4131
4132    Ok(())
4133}
4134
4135/// Start receiving chat updates for a channel
4136async fn join_channel_chat(
4137    request: proto::JoinChannelChat,
4138    response: Response<proto::JoinChannelChat>,
4139    session: MessageContext,
4140) -> Result<()> {
4141    let channel_id = ChannelId::from_proto(request.channel_id);
4142
4143    let db = session.db().await;
4144    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4145        .await?;
4146    let messages = db
4147        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4148        .await?;
4149    response.send(proto::JoinChannelChatResponse {
4150        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4151        messages,
4152    })?;
4153    Ok(())
4154}
4155
4156/// Stop receiving chat updates for a channel
4157async fn leave_channel_chat(
4158    request: proto::LeaveChannelChat,
4159    session: MessageContext,
4160) -> Result<()> {
4161    let channel_id = ChannelId::from_proto(request.channel_id);
4162    session
4163        .db()
4164        .await
4165        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4166        .await?;
4167    Ok(())
4168}
4169
4170/// Retrieve the chat history for a channel
4171async fn get_channel_messages(
4172    request: proto::GetChannelMessages,
4173    response: Response<proto::GetChannelMessages>,
4174    session: MessageContext,
4175) -> Result<()> {
4176    let channel_id = ChannelId::from_proto(request.channel_id);
4177    let messages = session
4178        .db()
4179        .await
4180        .get_channel_messages(
4181            channel_id,
4182            session.user_id(),
4183            MESSAGE_COUNT_PER_PAGE,
4184            Some(MessageId::from_proto(request.before_message_id)),
4185        )
4186        .await?;
4187    response.send(proto::GetChannelMessagesResponse {
4188        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4189        messages,
4190    })?;
4191    Ok(())
4192}
4193
4194/// Retrieve specific chat messages
4195async fn get_channel_messages_by_id(
4196    request: proto::GetChannelMessagesById,
4197    response: Response<proto::GetChannelMessagesById>,
4198    session: MessageContext,
4199) -> Result<()> {
4200    let message_ids = request
4201        .message_ids
4202        .iter()
4203        .map(|id| MessageId::from_proto(*id))
4204        .collect::<Vec<_>>();
4205    let messages = session
4206        .db()
4207        .await
4208        .get_channel_messages_by_id(session.user_id(), &message_ids)
4209        .await?;
4210    response.send(proto::GetChannelMessagesResponse {
4211        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4212        messages,
4213    })?;
4214    Ok(())
4215}
4216
4217/// Retrieve the current users notifications
4218async fn get_notifications(
4219    request: proto::GetNotifications,
4220    response: Response<proto::GetNotifications>,
4221    session: MessageContext,
4222) -> Result<()> {
4223    let notifications = session
4224        .db()
4225        .await
4226        .get_notifications(
4227            session.user_id(),
4228            NOTIFICATION_COUNT_PER_PAGE,
4229            request.before_id.map(db::NotificationId::from_proto),
4230        )
4231        .await?;
4232    response.send(proto::GetNotificationsResponse {
4233        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4234        notifications,
4235    })?;
4236    Ok(())
4237}
4238
4239/// Mark notifications as read
4240async fn mark_notification_as_read(
4241    request: proto::MarkNotificationRead,
4242    response: Response<proto::MarkNotificationRead>,
4243    session: MessageContext,
4244) -> Result<()> {
4245    let database = &session.db().await;
4246    let notifications = database
4247        .mark_notification_as_read_by_id(
4248            session.user_id(),
4249            NotificationId::from_proto(request.notification_id),
4250        )
4251        .await?;
4252    send_notifications(
4253        &*session.connection_pool().await,
4254        &session.peer,
4255        notifications,
4256    );
4257    response.send(proto::Ack {})?;
4258    Ok(())
4259}
4260
4261/// Get the current users information
4262async fn get_private_user_info(
4263    _request: proto::GetPrivateUserInfo,
4264    response: Response<proto::GetPrivateUserInfo>,
4265    session: MessageContext,
4266) -> Result<()> {
4267    let db = session.db().await;
4268
4269    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4270    let user = db
4271        .get_user_by_id(session.user_id())
4272        .await?
4273        .context("user not found")?;
4274    let flags = db.get_user_flags(session.user_id()).await?;
4275
4276    response.send(proto::GetPrivateUserInfoResponse {
4277        metrics_id,
4278        staff: user.admin,
4279        flags,
4280        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4281    })?;
4282    Ok(())
4283}
4284
4285/// Accept the terms of service (tos) on behalf of the current user
4286async fn accept_terms_of_service(
4287    _request: proto::AcceptTermsOfService,
4288    response: Response<proto::AcceptTermsOfService>,
4289    session: MessageContext,
4290) -> Result<()> {
4291    let db = session.db().await;
4292
4293    let accepted_tos_at = Utc::now();
4294    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4295        .await?;
4296
4297    response.send(proto::AcceptTermsOfServiceResponse {
4298        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4299    })?;
4300
4301    // When the user accepts the terms of service, we want to refresh their LLM
4302    // token to grant access.
4303    session
4304        .peer
4305        .send(session.connection_id, proto::RefreshLlmToken {})?;
4306
4307    Ok(())
4308}
4309
4310async fn get_llm_api_token(
4311    _request: proto::GetLlmToken,
4312    response: Response<proto::GetLlmToken>,
4313    session: MessageContext,
4314) -> Result<()> {
4315    let db = session.db().await;
4316
4317    let flags = db.get_user_flags(session.user_id()).await?;
4318
4319    let user_id = session.user_id();
4320    let user = db
4321        .get_user_by_id(user_id)
4322        .await?
4323        .with_context(|| format!("user {user_id} not found"))?;
4324
4325    if user.accepted_tos_at.is_none() {
4326        Err(anyhow!("terms of service not accepted"))?
4327    }
4328
4329    let stripe_client = session
4330        .app_state
4331        .stripe_client
4332        .as_ref()
4333        .context("failed to retrieve Stripe client")?;
4334
4335    let stripe_billing = session
4336        .app_state
4337        .stripe_billing
4338        .as_ref()
4339        .context("failed to retrieve Stripe billing object")?;
4340
4341    let billing_customer = if let Some(billing_customer) =
4342        db.get_billing_customer_by_user_id(user.id).await?
4343    {
4344        billing_customer
4345    } else {
4346        let customer_id = stripe_billing
4347            .find_or_create_customer_by_email(user.email_address.as_deref())
4348            .await?;
4349
4350        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
4351            .await?
4352            .context("billing customer not found")?
4353    };
4354
4355    let billing_subscription =
4356        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
4357            billing_subscription
4358        } else {
4359            let stripe_customer_id =
4360                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
4361
4362            let stripe_subscription = stripe_billing
4363                .subscribe_to_zed_free(stripe_customer_id)
4364                .await?;
4365
4366            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
4367                billing_customer_id: billing_customer.id,
4368                kind: Some(SubscriptionKind::ZedFree),
4369                stripe_subscription_id: stripe_subscription.id.to_string(),
4370                stripe_subscription_status: stripe_subscription.status.into(),
4371                stripe_cancellation_reason: None,
4372                stripe_current_period_start: Some(stripe_subscription.current_period_start),
4373                stripe_current_period_end: Some(stripe_subscription.current_period_end),
4374            })
4375            .await?
4376        };
4377
4378    let billing_preferences = db.get_billing_preferences(user.id).await?;
4379
4380    let token = LlmTokenClaims::create(
4381        &user,
4382        session.is_staff(),
4383        billing_customer,
4384        billing_preferences,
4385        &flags,
4386        billing_subscription,
4387        session.system_id.clone(),
4388        &session.app_state.config,
4389    )?;
4390    response.send(proto::GetLlmTokenResponse { token })?;
4391    Ok(())
4392}
4393
4394fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4395    let message = match message {
4396        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4397        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4398        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4399        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4400        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4401            code: frame.code.into(),
4402            reason: frame.reason.as_str().to_owned().into(),
4403        })),
4404        // We should never receive a frame while reading the message, according
4405        // to the `tungstenite` maintainers:
4406        //
4407        // > It cannot occur when you read messages from the WebSocket, but it
4408        // > can be used when you want to send the raw frames (e.g. you want to
4409        // > send the frames to the WebSocket without composing the full message first).
4410        // >
4411        // > — https://github.com/snapview/tungstenite-rs/issues/268
4412        TungsteniteMessage::Frame(_) => {
4413            bail!("received an unexpected frame while reading the message")
4414        }
4415    };
4416
4417    Ok(message)
4418}
4419
4420fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4421    match message {
4422        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4423        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4424        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4425        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4426        AxumMessage::Close(frame) => {
4427            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4428                code: frame.code.into(),
4429                reason: frame.reason.as_ref().into(),
4430            }))
4431        }
4432    }
4433}
4434
4435fn notify_membership_updated(
4436    connection_pool: &mut ConnectionPool,
4437    result: MembershipUpdated,
4438    user_id: UserId,
4439    peer: &Peer,
4440) {
4441    for membership in &result.new_channels.channel_memberships {
4442        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4443    }
4444    for channel_id in &result.removed_channels {
4445        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4446    }
4447
4448    let user_channels_update = proto::UpdateUserChannels {
4449        channel_memberships: result
4450            .new_channels
4451            .channel_memberships
4452            .iter()
4453            .map(|cm| proto::ChannelMembership {
4454                channel_id: cm.channel_id.to_proto(),
4455                role: cm.role.into(),
4456            })
4457            .collect(),
4458        ..Default::default()
4459    };
4460
4461    let mut update = build_channels_update(result.new_channels);
4462    update.delete_channels = result
4463        .removed_channels
4464        .into_iter()
4465        .map(|id| id.to_proto())
4466        .collect();
4467    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4468
4469    for connection_id in connection_pool.user_connection_ids(user_id) {
4470        peer.send(connection_id, user_channels_update.clone())
4471            .trace_err();
4472        peer.send(connection_id, update.clone()).trace_err();
4473    }
4474}
4475
4476fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4477    proto::UpdateUserChannels {
4478        channel_memberships: channels
4479            .channel_memberships
4480            .iter()
4481            .map(|m| proto::ChannelMembership {
4482                channel_id: m.channel_id.to_proto(),
4483                role: m.role.into(),
4484            })
4485            .collect(),
4486        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4487        observed_channel_message_id: channels.observed_channel_messages.clone(),
4488    }
4489}
4490
4491fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4492    let mut update = proto::UpdateChannels::default();
4493
4494    for channel in channels.channels {
4495        update.channels.push(channel.to_proto());
4496    }
4497
4498    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4499    update.latest_channel_message_ids = channels.latest_channel_messages;
4500
4501    for (channel_id, participants) in channels.channel_participants {
4502        update
4503            .channel_participants
4504            .push(proto::ChannelParticipants {
4505                channel_id: channel_id.to_proto(),
4506                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4507            });
4508    }
4509
4510    for channel in channels.invited_channels {
4511        update.channel_invitations.push(channel.to_proto());
4512    }
4513
4514    update
4515}
4516
4517fn build_initial_contacts_update(
4518    contacts: Vec<db::Contact>,
4519    pool: &ConnectionPool,
4520) -> proto::UpdateContacts {
4521    let mut update = proto::UpdateContacts::default();
4522
4523    for contact in contacts {
4524        match contact {
4525            db::Contact::Accepted { user_id, busy } => {
4526                update.contacts.push(contact_for_user(user_id, busy, pool));
4527            }
4528            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4529            db::Contact::Incoming { user_id } => {
4530                update
4531                    .incoming_requests
4532                    .push(proto::IncomingContactRequest {
4533                        requester_id: user_id.to_proto(),
4534                    })
4535            }
4536        }
4537    }
4538
4539    update
4540}
4541
4542fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4543    proto::Contact {
4544        user_id: user_id.to_proto(),
4545        online: pool.is_user_online(user_id),
4546        busy,
4547    }
4548}
4549
4550fn room_updated(room: &proto::Room, peer: &Peer) {
4551    broadcast(
4552        None,
4553        room.participants
4554            .iter()
4555            .filter_map(|participant| Some(participant.peer_id?.into())),
4556        |peer_id| {
4557            peer.send(
4558                peer_id,
4559                proto::RoomUpdated {
4560                    room: Some(room.clone()),
4561                },
4562            )
4563        },
4564    );
4565}
4566
4567fn channel_updated(
4568    channel: &db::channel::Model,
4569    room: &proto::Room,
4570    peer: &Peer,
4571    pool: &ConnectionPool,
4572) {
4573    let participants = room
4574        .participants
4575        .iter()
4576        .map(|p| p.user_id)
4577        .collect::<Vec<_>>();
4578
4579    broadcast(
4580        None,
4581        pool.channel_connection_ids(channel.root_id())
4582            .filter_map(|(channel_id, role)| {
4583                role.can_see_channel(channel.visibility)
4584                    .then_some(channel_id)
4585            }),
4586        |peer_id| {
4587            peer.send(
4588                peer_id,
4589                proto::UpdateChannels {
4590                    channel_participants: vec![proto::ChannelParticipants {
4591                        channel_id: channel.id.to_proto(),
4592                        participant_user_ids: participants.clone(),
4593                    }],
4594                    ..Default::default()
4595                },
4596            )
4597        },
4598    );
4599}
4600
4601async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4602    let db = session.db().await;
4603
4604    let contacts = db.get_contacts(user_id).await?;
4605    let busy = db.is_user_busy(user_id).await?;
4606
4607    let pool = session.connection_pool().await;
4608    let updated_contact = contact_for_user(user_id, busy, &pool);
4609    for contact in contacts {
4610        if let db::Contact::Accepted {
4611            user_id: contact_user_id,
4612            ..
4613        } = contact
4614        {
4615            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4616                session
4617                    .peer
4618                    .send(
4619                        contact_conn_id,
4620                        proto::UpdateContacts {
4621                            contacts: vec![updated_contact.clone()],
4622                            remove_contacts: Default::default(),
4623                            incoming_requests: Default::default(),
4624                            remove_incoming_requests: Default::default(),
4625                            outgoing_requests: Default::default(),
4626                            remove_outgoing_requests: Default::default(),
4627                        },
4628                    )
4629                    .trace_err();
4630            }
4631        }
4632    }
4633    Ok(())
4634}
4635
4636async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4637    let mut contacts_to_update = HashSet::default();
4638
4639    let room_id;
4640    let canceled_calls_to_user_ids;
4641    let livekit_room;
4642    let delete_livekit_room;
4643    let room;
4644    let channel;
4645
4646    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4647        contacts_to_update.insert(session.user_id());
4648
4649        for project in left_room.left_projects.values() {
4650            project_left(project, session);
4651        }
4652
4653        room_id = RoomId::from_proto(left_room.room.id);
4654        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4655        livekit_room = mem::take(&mut left_room.room.livekit_room);
4656        delete_livekit_room = left_room.deleted;
4657        room = mem::take(&mut left_room.room);
4658        channel = mem::take(&mut left_room.channel);
4659
4660        room_updated(&room, &session.peer);
4661    } else {
4662        return Ok(());
4663    }
4664
4665    if let Some(channel) = channel {
4666        channel_updated(
4667            &channel,
4668            &room,
4669            &session.peer,
4670            &*session.connection_pool().await,
4671        );
4672    }
4673
4674    {
4675        let pool = session.connection_pool().await;
4676        for canceled_user_id in canceled_calls_to_user_ids {
4677            for connection_id in pool.user_connection_ids(canceled_user_id) {
4678                session
4679                    .peer
4680                    .send(
4681                        connection_id,
4682                        proto::CallCanceled {
4683                            room_id: room_id.to_proto(),
4684                        },
4685                    )
4686                    .trace_err();
4687            }
4688            contacts_to_update.insert(canceled_user_id);
4689        }
4690    }
4691
4692    for contact_user_id in contacts_to_update {
4693        update_user_contacts(contact_user_id, session).await?;
4694    }
4695
4696    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4697        live_kit
4698            .remove_participant(livekit_room.clone(), session.user_id().to_string())
4699            .await
4700            .trace_err();
4701
4702        if delete_livekit_room {
4703            live_kit.delete_room(livekit_room).await.trace_err();
4704        }
4705    }
4706
4707    Ok(())
4708}
4709
4710async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4711    let left_channel_buffers = session
4712        .db()
4713        .await
4714        .leave_channel_buffers(session.connection_id)
4715        .await?;
4716
4717    for left_buffer in left_channel_buffers {
4718        channel_buffer_updated(
4719            session.connection_id,
4720            left_buffer.connections,
4721            &proto::UpdateChannelBufferCollaborators {
4722                channel_id: left_buffer.channel_id.to_proto(),
4723                collaborators: left_buffer.collaborators,
4724            },
4725            &session.peer,
4726        );
4727    }
4728
4729    Ok(())
4730}
4731
4732fn project_left(project: &db::LeftProject, session: &Session) {
4733    for connection_id in &project.connection_ids {
4734        if project.should_unshare {
4735            session
4736                .peer
4737                .send(
4738                    *connection_id,
4739                    proto::UnshareProject {
4740                        project_id: project.id.to_proto(),
4741                    },
4742                )
4743                .trace_err();
4744        } else {
4745            session
4746                .peer
4747                .send(
4748                    *connection_id,
4749                    proto::RemoveProjectCollaborator {
4750                        project_id: project.id.to_proto(),
4751                        peer_id: Some(session.connection_id.into()),
4752                    },
4753                )
4754                .trace_err();
4755        }
4756    }
4757}
4758
4759pub trait ResultExt {
4760    type Ok;
4761
4762    fn trace_err(self) -> Option<Self::Ok>;
4763}
4764
4765impl<T, E> ResultExt for Result<T, E>
4766where
4767    E: std::fmt::Debug,
4768{
4769    type Ok = T;
4770
4771    #[track_caller]
4772    fn trace_err(self) -> Option<T> {
4773        match self {
4774            Ok(value) => Some(value),
4775            Err(error) => {
4776                tracing::error!("{:?}", error);
4777                None
4778            }
4779        }
4780    }
4781}