rpc.rs

   1mod connection_pool;
   2
   3use crate::api::billing::find_or_create_billing_customer;
   4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   5use crate::db::billing_subscription::SubscriptionKind;
   6use crate::llm::db::LlmDatabase;
   7use crate::llm::{
   8    AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
   9    MIN_ACCOUNT_AGE_FOR_LLM_USE,
  10};
  11use crate::stripe_client::StripeCustomerId;
  12use crate::{
  13    AppState, Error, Result, auth,
  14    db::{
  15        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
  16        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  17        NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
  18        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  19    },
  20    executor::Executor,
  21};
  22use anyhow::{Context as _, anyhow, bail};
  23use async_tungstenite::tungstenite::{
  24    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  25};
  26use axum::headers::UserAgent;
  27use axum::{
  28    Extension, Router, TypedHeader,
  29    body::Body,
  30    extract::{
  31        ConnectInfo, WebSocketUpgrade,
  32        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  33    },
  34    headers::{Header, HeaderName},
  35    http::StatusCode,
  36    middleware,
  37    response::IntoResponse,
  38    routing::get,
  39};
  40use chrono::Utc;
  41use collections::{HashMap, HashSet};
  42pub use connection_pool::{ConnectionPool, ZedVersion};
  43use core::fmt::{self, Debug, Formatter};
  44use futures::TryFutureExt as _;
  45use reqwest_client::ReqwestClient;
  46use rpc::proto::{MultiLspQuery, split_repository_update};
  47use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  48use tracing::Span;
  49
  50use futures::{
  51    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  52    stream::FuturesUnordered,
  53};
  54use prometheus::{IntGauge, register_int_gauge};
  55use rpc::{
  56    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  57    proto::{
  58        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  59        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  60    },
  61};
  62use semantic_version::SemanticVersion;
  63use serde::{Serialize, Serializer};
  64use std::{
  65    any::TypeId,
  66    future::Future,
  67    marker::PhantomData,
  68    mem,
  69    net::SocketAddr,
  70    ops::{Deref, DerefMut},
  71    rc::Rc,
  72    sync::{
  73        Arc, OnceLock,
  74        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  75    },
  76    time::{Duration, Instant},
  77};
  78use time::OffsetDateTime;
  79use tokio::sync::{Semaphore, watch};
  80use tower::ServiceBuilder;
  81use tracing::{
  82    Instrument,
  83    field::{self},
  84    info_span, instrument,
  85};
  86
  87pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  88
  89// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  90pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  91
  92const MESSAGE_COUNT_PER_PAGE: usize = 100;
  93const MAX_MESSAGE_LEN: usize = 1024;
  94const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  95const MAX_CONCURRENT_CONNECTIONS: usize = 512;
  96
  97static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
  98
  99const TOTAL_DURATION_MS: &str = "total_duration_ms";
 100const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
 101const QUEUE_DURATION_MS: &str = "queue_duration_ms";
 102const HOST_WAITING_MS: &str = "host_waiting_ms";
 103
 104type MessageHandler =
 105    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
 106
 107pub struct ConnectionGuard;
 108
 109impl ConnectionGuard {
 110    pub fn try_acquire() -> Result<Self, ()> {
 111        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
 112        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 113            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 114            tracing::error!(
 115                "too many concurrent connections: {}",
 116                current_connections + 1
 117            );
 118            return Err(());
 119        }
 120        Ok(ConnectionGuard)
 121    }
 122}
 123
 124impl Drop for ConnectionGuard {
 125    fn drop(&mut self) {
 126        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 127    }
 128}
 129
 130struct Response<R> {
 131    peer: Arc<Peer>,
 132    receipt: Receipt<R>,
 133    responded: Arc<AtomicBool>,
 134}
 135
 136impl<R: RequestMessage> Response<R> {
 137    fn send(self, payload: R::Response) -> Result<()> {
 138        self.responded.store(true, SeqCst);
 139        self.peer.respond(self.receipt, payload)?;
 140        Ok(())
 141    }
 142}
 143
 144#[derive(Clone, Debug)]
 145pub enum Principal {
 146    User(User),
 147    Impersonated { user: User, admin: User },
 148}
 149
 150impl Principal {
 151    fn user(&self) -> &User {
 152        match self {
 153            Principal::User(user) => user,
 154            Principal::Impersonated { user, .. } => user,
 155        }
 156    }
 157
 158    fn update_span(&self, span: &tracing::Span) {
 159        match &self {
 160            Principal::User(user) => {
 161                span.record("user_id", user.id.0);
 162                span.record("login", &user.github_login);
 163            }
 164            Principal::Impersonated { user, admin } => {
 165                span.record("user_id", user.id.0);
 166                span.record("login", &user.github_login);
 167                span.record("impersonator", &admin.github_login);
 168            }
 169        }
 170    }
 171}
 172
 173#[derive(Clone)]
 174struct MessageContext {
 175    session: Session,
 176    span: tracing::Span,
 177}
 178
 179impl Deref for MessageContext {
 180    type Target = Session;
 181
 182    fn deref(&self) -> &Self::Target {
 183        &self.session
 184    }
 185}
 186
 187impl MessageContext {
 188    pub fn forward_request<T: RequestMessage>(
 189        &self,
 190        receiver_id: ConnectionId,
 191        request: T,
 192    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 193        let request_start_time = Instant::now();
 194        let span = self.span.clone();
 195        tracing::info!("start forwarding request");
 196        self.peer
 197            .forward_request(self.connection_id, receiver_id, request)
 198            .inspect(move |_| {
 199                span.record(
 200                    HOST_WAITING_MS,
 201                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 202                );
 203            })
 204            .inspect_err(|_| tracing::error!("error forwarding request"))
 205            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 206    }
 207}
 208
 209#[derive(Clone)]
 210struct Session {
 211    principal: Principal,
 212    connection_id: ConnectionId,
 213    db: Arc<tokio::sync::Mutex<DbHandle>>,
 214    peer: Arc<Peer>,
 215    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 216    app_state: Arc<AppState>,
 217    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 218    /// The GeoIP country code for the user.
 219    #[allow(unused)]
 220    geoip_country_code: Option<String>,
 221    system_id: Option<String>,
 222    _executor: Executor,
 223}
 224
 225impl Session {
 226    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 227        #[cfg(test)]
 228        tokio::task::yield_now().await;
 229        let guard = self.db.lock().await;
 230        #[cfg(test)]
 231        tokio::task::yield_now().await;
 232        guard
 233    }
 234
 235    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 236        #[cfg(test)]
 237        tokio::task::yield_now().await;
 238        let guard = self.connection_pool.lock();
 239        ConnectionPoolGuard {
 240            guard,
 241            _not_send: PhantomData,
 242        }
 243    }
 244
 245    fn is_staff(&self) -> bool {
 246        match &self.principal {
 247            Principal::User(user) => user.admin,
 248            Principal::Impersonated { .. } => true,
 249        }
 250    }
 251
 252    fn user_id(&self) -> UserId {
 253        match &self.principal {
 254            Principal::User(user) => user.id,
 255            Principal::Impersonated { user, .. } => user.id,
 256        }
 257    }
 258
 259    pub fn email(&self) -> Option<String> {
 260        match &self.principal {
 261            Principal::User(user) => user.email_address.clone(),
 262            Principal::Impersonated { user, .. } => user.email_address.clone(),
 263        }
 264    }
 265}
 266
 267impl Debug for Session {
 268    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 269        let mut result = f.debug_struct("Session");
 270        match &self.principal {
 271            Principal::User(user) => {
 272                result.field("user", &user.github_login);
 273            }
 274            Principal::Impersonated { user, admin } => {
 275                result.field("user", &user.github_login);
 276                result.field("impersonator", &admin.github_login);
 277            }
 278        }
 279        result.field("connection_id", &self.connection_id).finish()
 280    }
 281}
 282
 283struct DbHandle(Arc<Database>);
 284
 285impl Deref for DbHandle {
 286    type Target = Database;
 287
 288    fn deref(&self) -> &Self::Target {
 289        self.0.as_ref()
 290    }
 291}
 292
/// The collab RPC server: owns the peer, the shared connection pool and the table of
/// registered message handlers.
pub struct Server {
    /// This server's id. Behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type each one handles.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown signal: `teardown()` sends `true`, and the test-only `reset` sends `false` again.
    teardown: watch::Sender<bool>,
}
 301
/// Lock guard over the shared [`ConnectionPool`] that is intentionally `!Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is `!Send`, which makes the whole guard `!Send`. A future that holds it across
    // an `.await` therefore cannot be `Send`.
    _not_send: PhantomData<Rc<()>>,
}
 306
/// Serializable view of the server: the peer plus the (locked) connection pool.
///
/// The snapshot holds the pool's lock guard, so the pool stays locked until it is dropped.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the guard's `Deref` target, i.e. the `ConnectionPool` itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 313
 314pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 315where
 316    S: Serializer,
 317    T: Deref<Target = U>,
 318    U: Serialize,
 319{
 320    Serialize::serialize(value.deref(), serializer)
 321}
 322
 323impl Server {
 324    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 325        let mut server = Self {
 326            id: parking_lot::Mutex::new(id),
 327            peer: Peer::new(id.0 as u32),
 328            app_state: app_state.clone(),
 329            connection_pool: Default::default(),
 330            handlers: Default::default(),
 331            teardown: watch::channel(false).0,
 332        };
 333
 334        server
 335            .add_request_handler(ping)
 336            .add_request_handler(create_room)
 337            .add_request_handler(join_room)
 338            .add_request_handler(rejoin_room)
 339            .add_request_handler(leave_room)
 340            .add_request_handler(set_room_participant_role)
 341            .add_request_handler(call)
 342            .add_request_handler(cancel_call)
 343            .add_message_handler(decline_call)
 344            .add_request_handler(update_participant_location)
 345            .add_request_handler(share_project)
 346            .add_message_handler(unshare_project)
 347            .add_request_handler(join_project)
 348            .add_message_handler(leave_project)
 349            .add_request_handler(update_project)
 350            .add_request_handler(update_worktree)
 351            .add_request_handler(update_repository)
 352            .add_request_handler(remove_repository)
 353            .add_message_handler(start_language_server)
 354            .add_message_handler(update_language_server)
 355            .add_message_handler(update_diagnostic_summary)
 356            .add_message_handler(update_worktree_settings)
 357            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 358            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 359            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 360            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 361            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 362            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 363            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 364            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 365            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 366            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 367            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 368            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 369            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 370            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 371            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 372            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 373            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 374            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 375            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 376            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 377            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 378            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 379            .add_request_handler(
 380                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 381            )
 382            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 383            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 384            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 385            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 386            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
 387            .add_request_handler(
 388                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 389            )
 390            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 391            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 392            .add_request_handler(
 393                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 394            )
 395            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 396            .add_request_handler(
 397                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 398            )
 399            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 400            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 401            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 402            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 403            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 404            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 405            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 406            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 407            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 408            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 409            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 410            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 411            .add_request_handler(
 412                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 413            )
 414            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 415            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 416            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 417            .add_request_handler(multi_lsp_query)
 418            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 419            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 420            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 421            .add_message_handler(create_buffer_for_peer)
 422            .add_request_handler(update_buffer)
 423            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 424            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 425            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 426            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 427            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 428            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 429            .add_message_handler(
 430                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 431            )
 432            .add_request_handler(get_users)
 433            .add_request_handler(fuzzy_search_users)
 434            .add_request_handler(request_contact)
 435            .add_request_handler(remove_contact)
 436            .add_request_handler(respond_to_contact_request)
 437            .add_message_handler(subscribe_to_channels)
 438            .add_request_handler(create_channel)
 439            .add_request_handler(delete_channel)
 440            .add_request_handler(invite_channel_member)
 441            .add_request_handler(remove_channel_member)
 442            .add_request_handler(set_channel_member_role)
 443            .add_request_handler(set_channel_visibility)
 444            .add_request_handler(rename_channel)
 445            .add_request_handler(join_channel_buffer)
 446            .add_request_handler(leave_channel_buffer)
 447            .add_message_handler(update_channel_buffer)
 448            .add_request_handler(rejoin_channel_buffers)
 449            .add_request_handler(get_channel_members)
 450            .add_request_handler(respond_to_channel_invite)
 451            .add_request_handler(join_channel)
 452            .add_request_handler(join_channel_chat)
 453            .add_message_handler(leave_channel_chat)
 454            .add_request_handler(send_channel_message)
 455            .add_request_handler(remove_channel_message)
 456            .add_request_handler(update_channel_message)
 457            .add_request_handler(get_channel_messages)
 458            .add_request_handler(get_channel_messages_by_id)
 459            .add_request_handler(get_notifications)
 460            .add_request_handler(mark_notification_as_read)
 461            .add_request_handler(move_channel)
 462            .add_request_handler(reorder_channel)
 463            .add_request_handler(follow)
 464            .add_message_handler(unfollow)
 465            .add_message_handler(update_followers)
 466            .add_request_handler(get_private_user_info)
 467            .add_request_handler(get_llm_api_token)
 468            .add_request_handler(accept_terms_of_service)
 469            .add_message_handler(acknowledge_channel_message)
 470            .add_message_handler(acknowledge_buffer_version)
 471            .add_request_handler(get_supermaven_api_key)
 472            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 473            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 474            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 475            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 476            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 477            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 478            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 479            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 480            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 481            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 482            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 483            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 484            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 485            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 486            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 487            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 488            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 489            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 490            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 491            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 492            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 493            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 494            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 495            .add_message_handler(update_context);
 496
 497        Arc::new(server)
 498    }
 499
    /// Spawns a detached background task that cleans up resources left behind by stale
    /// server instances, then returns `Ok(())` immediately.
    ///
    /// After waiting `CLEANUP_TIMEOUT`, the task:
    /// 1. deletes stale channel-chat participants;
    /// 2. clears stale collaborators from each affected channel buffer and sends the refreshed
    ///    collaborator list to the connections the database returns;
    /// 3. for each affected room: clears stale participants, broadcasts the refreshed room
    ///    (and channel) state, sends `CallCanceled` for canceled calls, pushes updated contact
    ///    entries to the affected users' accepted contacts, and deletes the LiveKit room
    ///    when no participants remain;
    /// 4. repeats the chat-participant cleanup, clears old worktree entries and deletes stale
    ///    server records.
    ///
    /// Every failure is logged via `trace_err` and otherwise ignored, so the cleanup is
    /// best-effort.
    // NOTE(review): "stale" presumably means servers in this environment other than `server_id`;
    // the exact query lives in the db layer — confirm there.
    pub async fn start(&self) -> Result<()> {
        // Copy/clone everything the `async move` block needs so it owns its captures.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                // If this lookup fails, room/channel cleanup is skipped but the trailing
                // cleanup steps below still run.
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used if the room could not be refreshed: nobody is
                        // notified and the LiveKit room is left alone.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and users whose calls were canceled
                            // get a fresh contact entry pushed below.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            // Move data out instead of cloning; `refreshed_room` is not
                            // used after this block. Must happen after the broadcasts above.
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the synchronous pool lock is released before any `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Query the database first; the (parking_lot) pool lock is only
                            // taken afterwards so it is never held across these awaits.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Skip this user if either lookup failed.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts receive the update.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only when a LiveKit client is configured and the refreshed room has
                        // no participants left.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // NOTE(review): same call as at the start of the task; presumably repeated to
                // catch chat participants still present after room cleanup — confirm it is needed.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 670
 671    pub fn teardown(&self) {
 672        self.peer.teardown();
 673        self.connection_pool.lock().reset();
 674        let _ = self.teardown.send(true);
 675    }
 676
 677    #[cfg(test)]
 678    pub fn reset(&self, id: ServerId) {
 679        self.teardown();
 680        *self.id.lock() = id;
 681        self.peer.reset(id.0 as u32);
 682        let _ = self.teardown.send(false);
 683    }
 684
 685    #[cfg(test)]
 686    pub fn id(&self) -> ServerId {
 687        *self.id.lock()
 688    }
 689
 690    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 691    where
 692        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
 693        Fut: 'static + Send + Future<Output = Result<()>>,
 694        M: EnvelopedMessage,
 695    {
 696        let prev_handler = self.handlers.insert(
 697            TypeId::of::<M>(),
 698            Box::new(move |envelope, session, span| {
 699                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 700                let received_at = envelope.received_at;
 701                tracing::info!("message received");
 702                let start_time = Instant::now();
 703                let future = (handler)(
 704                    *envelope,
 705                    MessageContext {
 706                        session,
 707                        span: span.clone(),
 708                    },
 709                );
 710                async move {
 711                    let result = future.await;
 712                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 713                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 714                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 715                    span.record(TOTAL_DURATION_MS, total_duration_ms);
 716                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
 717                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
 718                    match result {
 719                        Err(error) => {
 720                            tracing::error!(?error, "error handling message")
 721                        }
 722                        Ok(()) => tracing::info!("finished handling message"),
 723                    }
 724                }
 725                .boxed()
 726            }),
 727        );
 728        if prev_handler.is_some() {
 729            panic!("registered a handler for the same message twice");
 730        }
 731        self
 732    }
 733
 734    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 735    where
 736        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 737        Fut: 'static + Send + Future<Output = Result<()>>,
 738        M: EnvelopedMessage,
 739    {
 740        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 741        self
 742    }
 743
 744    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 745    where
 746        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 747        Fut: Send + Future<Output = Result<()>>,
 748        M: RequestMessage,
 749    {
 750        let handler = Arc::new(handler);
 751        self.add_handler(move |envelope, session| {
 752            let receipt = envelope.receipt();
 753            let handler = handler.clone();
 754            async move {
 755                let peer = session.peer.clone();
 756                let responded = Arc::new(AtomicBool::default());
 757                let response = Response {
 758                    peer: peer.clone(),
 759                    responded: responded.clone(),
 760                    receipt,
 761                };
 762                match (handler)(envelope.payload, response, session).await {
 763                    Ok(()) => {
 764                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 765                            Ok(())
 766                        } else {
 767                            Err(anyhow!("handler did not send a response"))?
 768                        }
 769                    }
 770                    Err(error) => {
 771                        let proto_err = match &error {
 772                            Error::Internal(err) => err.to_proto(),
 773                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 774                        };
 775                        peer.respond_with_error(receipt, proto_err)?;
 776                        Err(error)
 777                    }
 778                }
 779            }
 780        })
 781    }
 782
    /// Returns a future that serves one client connection until it ends.
    ///
    /// The span and the teardown subscription are created synchronously. The
    /// returned future then:
    /// - registers `connection` with the peer;
    /// - builds a [`Session`] for `principal`;
    /// - sends the initial client update;
    /// - dispatches incoming messages to the registered handlers.
    ///
    /// The dispatch loop ends in one of three ways:
    /// - the I/O future completes;
    /// - the client closes the message stream;
    /// - the server tears down.
    ///
    /// In the first two cases `connection_lost` is run for cleanup. On teardown
    /// the future returns without running it.
    ///
    /// `send_connection_id`, when provided, is sent the assigned
    /// [`ConnectionId`] right after the hello message. `connection_guard` is
    /// kept alive until the initial client update has succeeded.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields declared as `Empty` are recorded later, once their values are
        // known (below, inside the future, or by `principal.update_span`).
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribed outside the async block so the returned future owns its
        // own teardown receiver.
        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse to serve the connection if teardown is already in progress.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            // Register the transport with the peer. The closure lets the peer
            // sleep through this executor (presumably for keepalive/timeouts —
            // TODO confirm in `Peer::add_connection`).
            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // A per-connection HTTP client; in this function it is only used to
            // build the optional Supermaven admin API client below.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            // NOTE(review): this early return (like the HTTP-client one above)
            // skips `connection_lost`. The peer connection added above, and any
            // connection-pool entry added inside `send_initial_client_update`
            // before its failing step, are therefore not cleaned up here.
            // Confirm they are reclaimed elsewhere.
            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The caller's connection slot is released only once the client has
            // been fully initialized.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            //
            // Every handler (foreground or background) holds a semaphore permit until it
            // finishes. The next message is only read once a permit is available, so at
            // most MAX_CONCURRENT_HANDLERS handlers run at once for this connection.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            // Number of permits currently held, i.e. handlers in flight.
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // Rebuilt every iteration: waits for a permit, then for the next
                // incoming message. If another branch wins, it is dropped here and
                // recreated on the next pass.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // Branches are checked in order: teardown, I/O completion, progress
                // on foreground handlers, then the next incoming message.
                futures::select_biased! {
                    // Server teardown: stop immediately, without `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                multi_lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            // The span is entered only while the handler future is
                            // constructed; it is attached to the future below via
                            // `instrument`.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // The permit is released only after the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Any foreground handlers still pending are dropped (not awaited)
            // before cleanup.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
 974
    /// Performs the server side of the handshake for a newly connected client.
    ///
    /// Sends `Hello` with the assigned peer id. If `send_connection_id` is
    /// provided, `connection_id` is also reported through it. Then, for the
    /// session's user (including impersonated sessions):
    /// - on the user's first connection (`connected_once` is false), sends
    ///   `ShowContacts` and sets `connected_once` to true in the database;
    /// - calls `update_user_plan`;
    /// - adds the connection to the connection pool and sends the initial
    ///   contacts update;
    /// - subscribes the user to channels when `should_auto_subscribe_to_channels`
    ///   returns true for `zed_version`;
    /// - forwards the user's pending incoming call, if there is one;
    /// - calls `update_user_contacts`.
    ///
    /// # Errors
    ///
    /// Returns the first send or database error; later steps are skipped.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // The result of notifying the waiter is deliberately ignored.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                update_user_plan(session).await?;

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                // The pool entry is added and the initial contacts update is built
                // under the same pool lock. Presumably this is so the snapshot is
                // computed against a pool that already includes this connection
                // (TODO confirm).
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
1032
1033    pub async fn invite_code_redeemed(
1034        self: &Arc<Self>,
1035        inviter_id: UserId,
1036        invitee_id: UserId,
1037    ) -> Result<()> {
1038        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1039            if let Some(code) = &user.invite_code {
1040                let pool = self.connection_pool.lock();
1041                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1042                for connection_id in pool.user_connection_ids(inviter_id) {
1043                    self.peer.send(
1044                        connection_id,
1045                        proto::UpdateContacts {
1046                            contacts: vec![invitee_contact.clone()],
1047                            ..Default::default()
1048                        },
1049                    )?;
1050                    self.peer.send(
1051                        connection_id,
1052                        proto::UpdateInviteInfo {
1053                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1054                            count: user.invite_count as u32,
1055                        },
1056                    )?;
1057                }
1058            }
1059        }
1060        Ok(())
1061    }
1062
1063    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1064        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1065            if let Some(invite_code) = &user.invite_code {
1066                let pool = self.connection_pool.lock();
1067                for connection_id in pool.user_connection_ids(user_id) {
1068                    self.peer.send(
1069                        connection_id,
1070                        proto::UpdateInviteInfo {
1071                            url: format!(
1072                                "{}{}",
1073                                self.app_state.config.invite_link_prefix, invite_code
1074                            ),
1075                            count: user.invite_count as u32,
1076                        },
1077                    )?;
1078                }
1079            }
1080        }
1081        Ok(())
1082    }
1083
1084    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1085        ServerSnapshot {
1086            connection_pool: ConnectionPoolGuard {
1087                guard: self.connection_pool.lock(),
1088                _not_send: PhantomData,
1089            },
1090            peer: &self.peer,
1091        }
1092    }
1093}
1094
1095impl Deref for ConnectionPoolGuard<'_> {
1096    type Target = ConnectionPool;
1097
1098    fn deref(&self) -> &Self::Target {
1099        &self.guard
1100    }
1101}
1102
1103impl DerefMut for ConnectionPoolGuard<'_> {
1104    fn deref_mut(&mut self) -> &mut Self::Target {
1105        &mut self.guard
1106    }
1107}
1108
1109impl Drop for ConnectionPoolGuard<'_> {
1110    fn drop(&mut self) {
1111        #[cfg(test)]
1112        self.check_invariants();
1113    }
1114}
1115
1116fn broadcast<F>(
1117    sender_id: Option<ConnectionId>,
1118    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1119    mut f: F,
1120) where
1121    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1122{
1123    for receiver_id in receiver_ids {
1124        if Some(receiver_id) != sender_id {
1125            if let Err(error) = f(receiver_id) {
1126                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1127            }
1128        }
1129    }
1130}
1131
1132pub struct ProtocolVersion(u32);
1133
1134impl Header for ProtocolVersion {
1135    fn name() -> &'static HeaderName {
1136        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1137        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1138    }
1139
1140    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1141    where
1142        Self: Sized,
1143        I: Iterator<Item = &'i axum::http::HeaderValue>,
1144    {
1145        let version = values
1146            .next()
1147            .ok_or_else(axum::headers::Error::invalid)?
1148            .to_str()
1149            .map_err(|_| axum::headers::Error::invalid())?
1150            .parse()
1151            .map_err(|_| axum::headers::Error::invalid())?;
1152        Ok(Self(version))
1153    }
1154
1155    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1156        values.extend([self.0.to_string().parse().unwrap()]);
1157    }
1158}
1159
1160pub struct AppVersionHeader(SemanticVersion);
1161impl Header for AppVersionHeader {
1162    fn name() -> &'static HeaderName {
1163        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1164        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1165    }
1166
1167    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1168    where
1169        Self: Sized,
1170        I: Iterator<Item = &'i axum::http::HeaderValue>,
1171    {
1172        let version = values
1173            .next()
1174            .ok_or_else(axum::headers::Error::invalid)?
1175            .to_str()
1176            .map_err(|_| axum::headers::Error::invalid())?
1177            .parse()
1178            .map_err(|_| axum::headers::Error::invalid())?;
1179        Ok(Self(version))
1180    }
1181
1182    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1183        values.extend([self.0.to_string().parse().unwrap()]);
1184    }
1185}
1186
1187#[derive(Debug)]
1188pub struct ReleaseChannelHeader(String);
1189
1190impl Header for ReleaseChannelHeader {
1191    fn name() -> &'static HeaderName {
1192        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1193        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1194    }
1195
1196    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1197    where
1198        Self: Sized,
1199        I: Iterator<Item = &'i axum::http::HeaderValue>,
1200    {
1201        Ok(Self(
1202            values
1203                .next()
1204                .ok_or_else(axum::headers::Error::invalid)?
1205                .to_str()
1206                .map_err(|_| axum::headers::Error::invalid())?
1207                .to_owned(),
1208        ))
1209    }
1210
1211    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1212        values.extend([self.0.parse().unwrap()]);
1213    }
1214}
1215
1216pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1217    Router::new()
1218        .route("/rpc", get(handle_websocket_request))
1219        .layer(
1220            ServiceBuilder::new()
1221                .layer(Extension(server.app_state.clone()))
1222                .layer(middleware::from_fn(auth::validate_header)),
1223        )
1224        .route("/metrics", get(handle_metrics))
1225        .layer(Extension(server))
1226}
1227
1228pub async fn handle_websocket_request(
1229    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1230    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1231    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1232    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1233    Extension(server): Extension<Arc<Server>>,
1234    Extension(principal): Extension<Principal>,
1235    user_agent: Option<TypedHeader<UserAgent>>,
1236    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1237    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1238    ws: WebSocketUpgrade,
1239) -> axum::response::Response {
1240    if protocol_version != rpc::PROTOCOL_VERSION {
1241        return (
1242            StatusCode::UPGRADE_REQUIRED,
1243            "client must be upgraded".to_string(),
1244        )
1245            .into_response();
1246    }
1247
1248    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1249        return (
1250            StatusCode::UPGRADE_REQUIRED,
1251            "no version header found".to_string(),
1252        )
1253            .into_response();
1254    };
1255
1256    let release_channel = release_channel_header.map(|header| header.0.0);
1257
1258    if !version.can_collaborate() {
1259        return (
1260            StatusCode::UPGRADE_REQUIRED,
1261            "client must be upgraded".to_string(),
1262        )
1263            .into_response();
1264    }
1265
1266    let socket_address = socket_address.to_string();
1267
1268    // Acquire connection guard before WebSocket upgrade
1269    let connection_guard = match ConnectionGuard::try_acquire() {
1270        Ok(guard) => guard,
1271        Err(()) => {
1272            return (
1273                StatusCode::SERVICE_UNAVAILABLE,
1274                "Too many concurrent connections",
1275            )
1276                .into_response();
1277        }
1278    };
1279
1280    ws.on_upgrade(move |socket| {
1281        let socket = socket
1282            .map_ok(to_tungstenite_message)
1283            .err_into()
1284            .with(|message| async move { to_axum_message(message) });
1285        let connection = Connection::new(Box::pin(socket));
1286        async move {
1287            server
1288                .handle_connection(
1289                    connection,
1290                    socket_address,
1291                    principal,
1292                    version,
1293                    release_channel,
1294                    user_agent.map(|header| header.to_string()),
1295                    country_code_header.map(|header| header.to_string()),
1296                    system_id_header.map(|header| header.to_string()),
1297                    None,
1298                    Executor::Production,
1299                    Some(connection_guard),
1300                )
1301                .await;
1302        }
1303    })
1304}
1305
1306pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1307    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1308    let connections_metric = CONNECTIONS_METRIC
1309        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1310
1311    let connections = server
1312        .connection_pool
1313        .lock()
1314        .connections()
1315        .filter(|connection| !connection.admin)
1316        .count();
1317    connections_metric.set(connections as _);
1318
1319    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1320    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1321        register_int_gauge!(
1322            "shared_projects",
1323            "number of open projects with one or more guests"
1324        )
1325        .unwrap()
1326    });
1327
1328    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1329    shared_projects_metric.set(shared_projects as _);
1330
1331    let encoder = prometheus::TextEncoder::new();
1332    let metric_families = prometheus::gather();
1333    let encoded_metrics = encoder
1334        .encode_to_string(&metric_families)
1335        .map_err(|err| anyhow!("{err}"))?;
1336    Ok(encoded_metrics)
1337}
1338
1339#[instrument(err, skip(executor))]
1340async fn connection_lost(
1341    session: Session,
1342    mut teardown: watch::Receiver<bool>,
1343    executor: Executor,
1344) -> Result<()> {
1345    session.peer.disconnect(session.connection_id);
1346    session
1347        .connection_pool()
1348        .await
1349        .remove_connection(session.connection_id)?;
1350
1351    session
1352        .db()
1353        .await
1354        .connection_lost(session.connection_id)
1355        .await
1356        .trace_err();
1357
1358    futures::select_biased! {
1359        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1360
1361            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1362            leave_room_for_session(&session, session.connection_id).await.trace_err();
1363            leave_channel_buffers_for_session(&session)
1364                .await
1365                .trace_err();
1366
1367            if !session
1368                .connection_pool()
1369                .await
1370                .is_user_online(session.user_id())
1371            {
1372                let db = session.db().await;
1373                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1374                    room_updated(&room, &session.peer);
1375                }
1376            }
1377
1378            update_user_contacts(session.user_id(), &session).await?;
1379        },
1380        _ = teardown.changed().fuse() => {}
1381    }
1382
1383    Ok(())
1384}
1385
1386/// Acknowledges a ping from a client, used to keep the connection alive.
1387async fn ping(
1388    _: proto::Ping,
1389    response: Response<proto::Ping>,
1390    _session: MessageContext,
1391) -> Result<()> {
1392    response.send(proto::Ack {})?;
1393    Ok(())
1394}
1395
1396/// Creates a new room for calling (outside of channels)
1397async fn create_room(
1398    _request: proto::CreateRoom,
1399    response: Response<proto::CreateRoom>,
1400    session: MessageContext,
1401) -> Result<()> {
1402    let livekit_room = nanoid::nanoid!(30);
1403
1404    let live_kit_connection_info = util::maybe!(async {
1405        let live_kit = session.app_state.livekit_client.as_ref();
1406        let live_kit = live_kit?;
1407        let user_id = session.user_id().to_string();
1408
1409        let token = live_kit
1410            .room_token(&livekit_room, &user_id.to_string())
1411            .trace_err()?;
1412
1413        Some(proto::LiveKitConnectionInfo {
1414            server_url: live_kit.url().into(),
1415            token,
1416            can_publish: true,
1417        })
1418    })
1419    .await;
1420
1421    let room = session
1422        .db()
1423        .await
1424        .create_room(session.user_id(), session.connection_id, &livekit_room)
1425        .await?;
1426
1427    response.send(proto::CreateRoomResponse {
1428        room: Some(room.clone()),
1429        live_kit_connection_info,
1430    })?;
1431
1432    update_user_contacts(session.user_id(), &session).await?;
1433    Ok(())
1434}
1435
1436/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1437async fn join_room(
1438    request: proto::JoinRoom,
1439    response: Response<proto::JoinRoom>,
1440    session: MessageContext,
1441) -> Result<()> {
1442    let room_id = RoomId::from_proto(request.id);
1443
1444    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1445
1446    if let Some(channel_id) = channel_id {
1447        return join_channel_internal(channel_id, Box::new(response), session).await;
1448    }
1449
1450    let joined_room = {
1451        let room = session
1452            .db()
1453            .await
1454            .join_room(room_id, session.user_id(), session.connection_id)
1455            .await?;
1456        room_updated(&room.room, &session.peer);
1457        room.into_inner()
1458    };
1459
1460    for connection_id in session
1461        .connection_pool()
1462        .await
1463        .user_connection_ids(session.user_id())
1464    {
1465        session
1466            .peer
1467            .send(
1468                connection_id,
1469                proto::CallCanceled {
1470                    room_id: room_id.to_proto(),
1471                },
1472            )
1473            .trace_err();
1474    }
1475
1476    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1477    {
1478        live_kit
1479            .room_token(
1480                &joined_room.room.livekit_room,
1481                &session.user_id().to_string(),
1482            )
1483            .trace_err()
1484            .map(|token| proto::LiveKitConnectionInfo {
1485                server_url: live_kit.url().into(),
1486                token,
1487                can_publish: true,
1488            })
1489    } else {
1490        None
1491    };
1492
1493    response.send(proto::JoinRoomResponse {
1494        room: Some(joined_room.room),
1495        channel_id: None,
1496        live_kit_connection_info,
1497    })?;
1498
1499    update_user_contacts(session.user_id(), &session).await?;
1500    Ok(())
1501}
1502
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database re-associates this connection with the room and reports which
/// projects were reshared (this connection is a host) and which were rejoined
/// (this connection is a guest). The caller gets the room plus both project
/// lists. Collaborators of reshared projects are told about the new peer id and
/// receive the host's current worktrees. Rejoined projects are streamed back to
/// the caller, and the channel (if any) and the user's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    // Filled in inside the scope below so they outlive the database guard.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that the host's old peer id is now this connection.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Relay the host's current worktree list to the collaborators.
            // NOTE(review): the first argument presumably excludes this connection; confirm in `broadcast`.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Consume the guard so only plain data escapes this scope.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1594
1595fn notify_rejoined_projects(
1596    rejoined_projects: &mut Vec<RejoinedProject>,
1597    session: &Session,
1598) -> Result<()> {
1599    for project in rejoined_projects.iter() {
1600        for collaborator in &project.collaborators {
1601            session
1602                .peer
1603                .send(
1604                    collaborator.connection_id,
1605                    proto::UpdateProjectCollaborator {
1606                        project_id: project.id.to_proto(),
1607                        old_peer_id: Some(project.old_connection_id.into()),
1608                        new_peer_id: Some(session.connection_id.into()),
1609                    },
1610                )
1611                .trace_err();
1612        }
1613    }
1614
1615    for project in rejoined_projects {
1616        for worktree in mem::take(&mut project.worktrees) {
1617            // Stream this worktree's entries.
1618            let message = proto::UpdateWorktree {
1619                project_id: project.id.to_proto(),
1620                worktree_id: worktree.id,
1621                abs_path: worktree.abs_path.clone(),
1622                root_name: worktree.root_name,
1623                updated_entries: worktree.updated_entries,
1624                removed_entries: worktree.removed_entries,
1625                scan_id: worktree.scan_id,
1626                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1627                updated_repositories: worktree.updated_repositories,
1628                removed_repositories: worktree.removed_repositories,
1629            };
1630            for update in proto::split_worktree_update(message) {
1631                session.peer.send(session.connection_id, update)?;
1632            }
1633
1634            // Stream this worktree's diagnostics.
1635            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1636            if let Some(summary) = worktree_diagnostics.next() {
1637                let message = proto::UpdateDiagnosticSummary {
1638                    project_id: project.id.to_proto(),
1639                    worktree_id: worktree.id,
1640                    summary: Some(summary),
1641                    more_summaries: worktree_diagnostics.collect(),
1642                };
1643                session.peer.send(session.connection_id, message)?;
1644            }
1645
1646            for settings_file in worktree.settings_files {
1647                session.peer.send(
1648                    session.connection_id,
1649                    proto::UpdateWorktreeSettings {
1650                        project_id: project.id.to_proto(),
1651                        worktree_id: worktree.id,
1652                        path: settings_file.path,
1653                        content: Some(settings_file.content),
1654                        kind: Some(settings_file.kind.to_proto().into()),
1655                    },
1656                )?;
1657            }
1658        }
1659
1660        for repository in mem::take(&mut project.updated_repositories) {
1661            for update in split_repository_update(repository) {
1662                session.peer.send(session.connection_id, update)?;
1663            }
1664        }
1665
1666        for id in mem::take(&mut project.removed_repositories) {
1667            session.peer.send(
1668                session.connection_id,
1669                proto::RemoveRepository {
1670                    project_id: project.id.to_proto(),
1671                    id,
1672                },
1673            )?;
1674        }
1675    }
1676
1677    Ok(())
1678}
1679
1680/// leave room disconnects from the room.
1681async fn leave_room(
1682    _: proto::LeaveRoom,
1683    response: Response<proto::LeaveRoom>,
1684    session: MessageContext,
1685) -> Result<()> {
1686    leave_room_for_session(&session, session.connection_id).await?;
1687    response.send(proto::Ack {})?;
1688    Ok(())
1689}
1690
1691/// Updates the permissions of someone else in the room.
1692async fn set_room_participant_role(
1693    request: proto::SetRoomParticipantRole,
1694    response: Response<proto::SetRoomParticipantRole>,
1695    session: MessageContext,
1696) -> Result<()> {
1697    let user_id = UserId::from_proto(request.user_id);
1698    let role = ChannelRole::from(request.role());
1699
1700    let (livekit_room, can_publish) = {
1701        let room = session
1702            .db()
1703            .await
1704            .set_room_participant_role(
1705                session.user_id(),
1706                RoomId::from_proto(request.room_id),
1707                user_id,
1708                role,
1709            )
1710            .await?;
1711
1712        let livekit_room = room.livekit_room.clone();
1713        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1714        room_updated(&room, &session.peer);
1715        (livekit_room, can_publish)
1716    };
1717
1718    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1719        live_kit
1720            .update_participant(
1721                livekit_room.clone(),
1722                request.user_id.to_string(),
1723                livekit_api::proto::ParticipantPermission {
1724                    can_subscribe: true,
1725                    can_publish,
1726                    can_publish_data: can_publish,
1727                    hidden: false,
1728                    recorder: false,
1729                },
1730            )
1731            .await
1732            .trace_err();
1733    }
1734
1735    response.send(proto::Ack {})?;
1736    Ok(())
1737}
1738
/// Call someone else into the current room.
///
/// The called user must be a contact of the caller. The call is recorded in the
/// database and the room update is broadcast. Then every connection of the
/// called user is sent an incoming-call request concurrently. The first
/// connection that accepts the request causes the caller to be acknowledged. If
/// no connection accepts, the call is marked as failed, the room and contacts
/// are updated again, and an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Scoped so the database guard is dropped before contacts are updated and
    // the called user's clients are rung; `mem::take` moves the incoming-call
    // message out of the guard so it can be used afterwards.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring all of the called user's connections at once.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // The first successful response wins; failures are only logged.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection accepted the call: roll it back in the database.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1807
1808/// Cancel an outgoing call.
1809async fn cancel_call(
1810    request: proto::CancelCall,
1811    response: Response<proto::CancelCall>,
1812    session: MessageContext,
1813) -> Result<()> {
1814    let called_user_id = UserId::from_proto(request.called_user_id);
1815    let room_id = RoomId::from_proto(request.room_id);
1816    {
1817        let room = session
1818            .db()
1819            .await
1820            .cancel_call(room_id, session.connection_id, called_user_id)
1821            .await?;
1822        room_updated(&room, &session.peer);
1823    }
1824
1825    for connection_id in session
1826        .connection_pool()
1827        .await
1828        .user_connection_ids(called_user_id)
1829    {
1830        session
1831            .peer
1832            .send(
1833                connection_id,
1834                proto::CallCanceled {
1835                    room_id: room_id.to_proto(),
1836                },
1837            )
1838            .trace_err();
1839    }
1840    response.send(proto::Ack {})?;
1841
1842    update_user_contacts(called_user_id, &session).await?;
1843    Ok(())
1844}
1845
1846/// Decline an incoming call.
1847async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1848    let room_id = RoomId::from_proto(message.room_id);
1849    {
1850        let room = session
1851            .db()
1852            .await
1853            .decline_call(Some(room_id), session.user_id())
1854            .await?
1855            .context("declining call")?;
1856        room_updated(&room, &session.peer);
1857    }
1858
1859    for connection_id in session
1860        .connection_pool()
1861        .await
1862        .user_connection_ids(session.user_id())
1863    {
1864        session
1865            .peer
1866            .send(
1867                connection_id,
1868                proto::CallCanceled {
1869                    room_id: room_id.to_proto(),
1870                },
1871            )
1872            .trace_err();
1873    }
1874    update_user_contacts(session.user_id(), &session).await?;
1875    Ok(())
1876}
1877
1878/// Updates other participants in the room with your current location.
1879async fn update_participant_location(
1880    request: proto::UpdateParticipantLocation,
1881    response: Response<proto::UpdateParticipantLocation>,
1882    session: MessageContext,
1883) -> Result<()> {
1884    let room_id = RoomId::from_proto(request.room_id);
1885    let location = request.location.context("invalid location")?;
1886
1887    let db = session.db().await;
1888    let room = db
1889        .update_room_participant_location(room_id, session.connection_id, location)
1890        .await?;
1891
1892    room_updated(&room, &session.peer);
1893    response.send(proto::Ack {})?;
1894    Ok(())
1895}
1896
1897/// Share a project into the room.
1898async fn share_project(
1899    request: proto::ShareProject,
1900    response: Response<proto::ShareProject>,
1901    session: MessageContext,
1902) -> Result<()> {
1903    let (project_id, room) = &*session
1904        .db()
1905        .await
1906        .share_project(
1907            RoomId::from_proto(request.room_id),
1908            session.connection_id,
1909            &request.worktrees,
1910            request.is_ssh_project,
1911        )
1912        .await?;
1913    response.send(proto::ShareProjectResponse {
1914        project_id: project_id.to_proto(),
1915    })?;
1916    room_updated(room, &session.peer);
1917
1918    Ok(())
1919}
1920
1921/// Unshare a project from the room.
1922async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1923    let project_id = ProjectId::from_proto(message.project_id);
1924    unshare_project_internal(project_id, session.connection_id, &session).await
1925}
1926
/// Stops sharing `project_id` on behalf of `connection_id`.
///
/// Every guest connection returned by the database is sent an
/// `UnshareProject` message, and the room (if any) is re-broadcast. If the
/// database reports that the project should be deleted, it is deleted
/// afterwards.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // Scoped so `room_guard` is dropped before the database handle is
    // acquired again for the deletion below.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1964
1965/// Join someone elses shared project.
1966async fn join_project(
1967    request: proto::JoinProject,
1968    response: Response<proto::JoinProject>,
1969    session: MessageContext,
1970) -> Result<()> {
1971    let project_id = ProjectId::from_proto(request.project_id);
1972
1973    tracing::info!(%project_id, "join project");
1974
1975    let db = session.db().await;
1976    let (project, replica_id) = &mut *db
1977        .join_project(
1978            project_id,
1979            session.connection_id,
1980            session.user_id(),
1981            request.committer_name.clone(),
1982            request.committer_email.clone(),
1983        )
1984        .await?;
1985    drop(db);
1986    tracing::info!(%project_id, "join remote project");
1987    let collaborators = project
1988        .collaborators
1989        .iter()
1990        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1991        .map(|collaborator| collaborator.to_proto())
1992        .collect::<Vec<_>>();
1993    let project_id = project.id;
1994    let guest_user_id = session.user_id();
1995
1996    let worktrees = project
1997        .worktrees
1998        .iter()
1999        .map(|(id, worktree)| proto::WorktreeMetadata {
2000            id: *id,
2001            root_name: worktree.root_name.clone(),
2002            visible: worktree.visible,
2003            abs_path: worktree.abs_path.clone(),
2004        })
2005        .collect::<Vec<_>>();
2006
2007    let add_project_collaborator = proto::AddProjectCollaborator {
2008        project_id: project_id.to_proto(),
2009        collaborator: Some(proto::Collaborator {
2010            peer_id: Some(session.connection_id.into()),
2011            replica_id: replica_id.0 as u32,
2012            user_id: guest_user_id.to_proto(),
2013            is_host: false,
2014            committer_name: request.committer_name.clone(),
2015            committer_email: request.committer_email.clone(),
2016        }),
2017    };
2018
2019    for collaborator in &collaborators {
2020        session
2021            .peer
2022            .send(
2023                collaborator.peer_id.unwrap().into(),
2024                add_project_collaborator.clone(),
2025            )
2026            .trace_err();
2027    }
2028
2029    // First, we send the metadata associated with each worktree.
2030    let (language_servers, language_server_capabilities) = project
2031        .language_servers
2032        .clone()
2033        .into_iter()
2034        .map(|server| (server.server, server.capabilities))
2035        .unzip();
2036    response.send(proto::JoinProjectResponse {
2037        project_id: project.id.0 as u64,
2038        worktrees: worktrees.clone(),
2039        replica_id: replica_id.0 as u32,
2040        collaborators: collaborators.clone(),
2041        language_servers,
2042        language_server_capabilities,
2043        role: project.role.into(),
2044    })?;
2045
2046    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2047        // Stream this worktree's entries.
2048        let message = proto::UpdateWorktree {
2049            project_id: project_id.to_proto(),
2050            worktree_id,
2051            abs_path: worktree.abs_path.clone(),
2052            root_name: worktree.root_name,
2053            updated_entries: worktree.entries,
2054            removed_entries: Default::default(),
2055            scan_id: worktree.scan_id,
2056            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2057            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2058            removed_repositories: Default::default(),
2059        };
2060        for update in proto::split_worktree_update(message) {
2061            session.peer.send(session.connection_id, update.clone())?;
2062        }
2063
2064        // Stream this worktree's diagnostics.
2065        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
2066        if let Some(summary) = worktree_diagnostics.next() {
2067            let message = proto::UpdateDiagnosticSummary {
2068                project_id: project.id.to_proto(),
2069                worktree_id: worktree.id,
2070                summary: Some(summary),
2071                more_summaries: worktree_diagnostics.collect(),
2072            };
2073            session.peer.send(session.connection_id, message)?;
2074        }
2075
2076        for settings_file in worktree.settings_files {
2077            session.peer.send(
2078                session.connection_id,
2079                proto::UpdateWorktreeSettings {
2080                    project_id: project_id.to_proto(),
2081                    worktree_id: worktree.id,
2082                    path: settings_file.path,
2083                    content: Some(settings_file.content),
2084                    kind: Some(settings_file.kind.to_proto() as i32),
2085                },
2086            )?;
2087        }
2088    }
2089
2090    for repository in mem::take(&mut project.repositories) {
2091        for update in split_repository_update(repository) {
2092            session.peer.send(session.connection_id, update)?;
2093        }
2094    }
2095
2096    for language_server in &project.language_servers {
2097        session.peer.send(
2098            session.connection_id,
2099            proto::UpdateLanguageServer {
2100                project_id: project_id.to_proto(),
2101                server_name: Some(language_server.server.name.clone()),
2102                language_server_id: language_server.server.id,
2103                variant: Some(
2104                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2105                        proto::LspDiskBasedDiagnosticsUpdated {},
2106                    ),
2107                ),
2108            },
2109        )?;
2110    }
2111
2112    Ok(())
2113}
2114
2115/// Leave someone elses shared project.
2116async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2117    let sender_id = session.connection_id;
2118    let project_id = ProjectId::from_proto(request.project_id);
2119    let db = session.db().await;
2120
2121    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2122    tracing::info!(
2123        %project_id,
2124        "leave project"
2125    );
2126
2127    project_left(project, &session);
2128    if let Some(room) = room {
2129        room_updated(room, &session.peer);
2130    }
2131
2132    Ok(())
2133}
2134
2135/// Updates other participants with changes to the project
2136async fn update_project(
2137    request: proto::UpdateProject,
2138    response: Response<proto::UpdateProject>,
2139    session: MessageContext,
2140) -> Result<()> {
2141    let project_id = ProjectId::from_proto(request.project_id);
2142    let (room, guest_connection_ids) = &*session
2143        .db()
2144        .await
2145        .update_project(project_id, session.connection_id, &request.worktrees)
2146        .await?;
2147    broadcast(
2148        Some(session.connection_id),
2149        guest_connection_ids.iter().copied(),
2150        |connection_id| {
2151            session
2152                .peer
2153                .forward_send(session.connection_id, connection_id, request.clone())
2154        },
2155    );
2156    if let Some(room) = room {
2157        room_updated(room, &session.peer);
2158    }
2159    response.send(proto::Ack {})?;
2160
2161    Ok(())
2162}
2163
2164/// Updates other participants with changes to the worktree
2165async fn update_worktree(
2166    request: proto::UpdateWorktree,
2167    response: Response<proto::UpdateWorktree>,
2168    session: MessageContext,
2169) -> Result<()> {
2170    let guest_connection_ids = session
2171        .db()
2172        .await
2173        .update_worktree(&request, session.connection_id)
2174        .await?;
2175
2176    broadcast(
2177        Some(session.connection_id),
2178        guest_connection_ids.iter().copied(),
2179        |connection_id| {
2180            session
2181                .peer
2182                .forward_send(session.connection_id, connection_id, request.clone())
2183        },
2184    );
2185    response.send(proto::Ack {})?;
2186    Ok(())
2187}
2188
2189async fn update_repository(
2190    request: proto::UpdateRepository,
2191    response: Response<proto::UpdateRepository>,
2192    session: MessageContext,
2193) -> Result<()> {
2194    let guest_connection_ids = session
2195        .db()
2196        .await
2197        .update_repository(&request, session.connection_id)
2198        .await?;
2199
2200    broadcast(
2201        Some(session.connection_id),
2202        guest_connection_ids.iter().copied(),
2203        |connection_id| {
2204            session
2205                .peer
2206                .forward_send(session.connection_id, connection_id, request.clone())
2207        },
2208    );
2209    response.send(proto::Ack {})?;
2210    Ok(())
2211}
2212
2213async fn remove_repository(
2214    request: proto::RemoveRepository,
2215    response: Response<proto::RemoveRepository>,
2216    session: MessageContext,
2217) -> Result<()> {
2218    let guest_connection_ids = session
2219        .db()
2220        .await
2221        .remove_repository(&request, session.connection_id)
2222        .await?;
2223
2224    broadcast(
2225        Some(session.connection_id),
2226        guest_connection_ids.iter().copied(),
2227        |connection_id| {
2228            session
2229                .peer
2230                .forward_send(session.connection_id, connection_id, request.clone())
2231        },
2232    );
2233    response.send(proto::Ack {})?;
2234    Ok(())
2235}
2236
2237/// Updates other participants with changes to the diagnostics
2238async fn update_diagnostic_summary(
2239    message: proto::UpdateDiagnosticSummary,
2240    session: MessageContext,
2241) -> Result<()> {
2242    let guest_connection_ids = session
2243        .db()
2244        .await
2245        .update_diagnostic_summary(&message, session.connection_id)
2246        .await?;
2247
2248    broadcast(
2249        Some(session.connection_id),
2250        guest_connection_ids.iter().copied(),
2251        |connection_id| {
2252            session
2253                .peer
2254                .forward_send(session.connection_id, connection_id, message.clone())
2255        },
2256    );
2257
2258    Ok(())
2259}
2260
2261/// Updates other participants with changes to the worktree settings
2262async fn update_worktree_settings(
2263    message: proto::UpdateWorktreeSettings,
2264    session: MessageContext,
2265) -> Result<()> {
2266    let guest_connection_ids = session
2267        .db()
2268        .await
2269        .update_worktree_settings(&message, session.connection_id)
2270        .await?;
2271
2272    broadcast(
2273        Some(session.connection_id),
2274        guest_connection_ids.iter().copied(),
2275        |connection_id| {
2276            session
2277                .peer
2278                .forward_send(session.connection_id, connection_id, message.clone())
2279        },
2280    );
2281
2282    Ok(())
2283}
2284
2285/// Notify other participants that a language server has started.
2286async fn start_language_server(
2287    request: proto::StartLanguageServer,
2288    session: MessageContext,
2289) -> Result<()> {
2290    let guest_connection_ids = session
2291        .db()
2292        .await
2293        .start_language_server(&request, session.connection_id)
2294        .await?;
2295
2296    broadcast(
2297        Some(session.connection_id),
2298        guest_connection_ids.iter().copied(),
2299        |connection_id| {
2300            session
2301                .peer
2302                .forward_send(session.connection_id, connection_id, request.clone())
2303        },
2304    );
2305    Ok(())
2306}
2307
2308/// Notify other participants that a language server has changed.
2309async fn update_language_server(
2310    request: proto::UpdateLanguageServer,
2311    session: MessageContext,
2312) -> Result<()> {
2313    let project_id = ProjectId::from_proto(request.project_id);
2314    let db = session.db().await;
2315
2316    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2317    {
2318        if let Some(capabilities) = update.capabilities.clone() {
2319            db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2320                .await?;
2321        }
2322    }
2323
2324    let project_connection_ids = db
2325        .project_connection_ids(project_id, session.connection_id, true)
2326        .await?;
2327    broadcast(
2328        Some(session.connection_id),
2329        project_connection_ids.iter().copied(),
2330        |connection_id| {
2331            session
2332                .peer
2333                .forward_send(session.connection_id, connection_id, request.clone())
2334        },
2335    );
2336    Ok(())
2337}
2338
2339/// forward a project request to the host. These requests should be read only
2340/// as guests are allowed to send them.
2341async fn forward_read_only_project_request<T>(
2342    request: T,
2343    response: Response<T>,
2344    session: MessageContext,
2345) -> Result<()>
2346where
2347    T: EntityMessage + RequestMessage,
2348{
2349    let project_id = ProjectId::from_proto(request.remote_entity_id());
2350    let host_connection_id = session
2351        .db()
2352        .await
2353        .host_for_read_only_project_request(project_id, session.connection_id)
2354        .await?;
2355    let payload = session.forward_request(host_connection_id, request).await?;
2356    response.send(payload)?;
2357    Ok(())
2358}
2359
2360/// forward a project request to the host. These requests are disallowed
2361/// for guests.
2362async fn forward_mutating_project_request<T>(
2363    request: T,
2364    response: Response<T>,
2365    session: MessageContext,
2366) -> Result<()>
2367where
2368    T: EntityMessage + RequestMessage,
2369{
2370    let project_id = ProjectId::from_proto(request.remote_entity_id());
2371
2372    let host_connection_id = session
2373        .db()
2374        .await
2375        .host_for_mutating_project_request(project_id, session.connection_id)
2376        .await?;
2377    let payload = session.forward_request(host_connection_id, request).await?;
2378    response.send(payload)?;
2379    Ok(())
2380}
2381
2382async fn multi_lsp_query(
2383    request: MultiLspQuery,
2384    response: Response<MultiLspQuery>,
2385    session: MessageContext,
2386) -> Result<()> {
2387    tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2388    tracing::info!("multi_lsp_query message received");
2389    forward_mutating_project_request(request, response, session).await
2390}
2391
2392/// Notify other participants that a new buffer has been created
2393async fn create_buffer_for_peer(
2394    request: proto::CreateBufferForPeer,
2395    session: MessageContext,
2396) -> Result<()> {
2397    session
2398        .db()
2399        .await
2400        .check_user_is_project_host(
2401            ProjectId::from_proto(request.project_id),
2402            session.connection_id,
2403        )
2404        .await?;
2405    let peer_id = request.peer_id.context("invalid peer id")?;
2406    session
2407        .peer
2408        .forward_send(session.connection_id, peer_id.into(), request)?;
2409    Ok(())
2410}
2411
2412/// Notify other participants that a buffer has been updated. This is
2413/// allowed for guests as long as the update is limited to selections.
2414async fn update_buffer(
2415    request: proto::UpdateBuffer,
2416    response: Response<proto::UpdateBuffer>,
2417    session: MessageContext,
2418) -> Result<()> {
2419    let project_id = ProjectId::from_proto(request.project_id);
2420    let mut capability = Capability::ReadOnly;
2421
2422    for op in request.operations.iter() {
2423        match op.variant {
2424            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2425            Some(_) => capability = Capability::ReadWrite,
2426        }
2427    }
2428
2429    let host = {
2430        let guard = session
2431            .db()
2432            .await
2433            .connections_for_buffer_update(project_id, session.connection_id, capability)
2434            .await?;
2435
2436        let (host, guests) = &*guard;
2437
2438        broadcast(
2439            Some(session.connection_id),
2440            guests.clone(),
2441            |connection_id| {
2442                session
2443                    .peer
2444                    .forward_send(session.connection_id, connection_id, request.clone())
2445            },
2446        );
2447
2448        *host
2449    };
2450
2451    if host != session.connection_id {
2452        session.forward_request(host, request.clone()).await?;
2453    }
2454
2455    response.send(proto::Ack {})?;
2456    Ok(())
2457}
2458
2459async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2460    let project_id = ProjectId::from_proto(message.project_id);
2461
2462    let operation = message.operation.as_ref().context("invalid operation")?;
2463    let capability = match operation.variant.as_ref() {
2464        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2465            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2466                match buffer_op.variant {
2467                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2468                        Capability::ReadOnly
2469                    }
2470                    _ => Capability::ReadWrite,
2471                }
2472            } else {
2473                Capability::ReadWrite
2474            }
2475        }
2476        Some(_) => Capability::ReadWrite,
2477        None => Capability::ReadOnly,
2478    };
2479
2480    let guard = session
2481        .db()
2482        .await
2483        .connections_for_buffer_update(project_id, session.connection_id, capability)
2484        .await?;
2485
2486    let (host, guests) = &*guard;
2487
2488    broadcast(
2489        Some(session.connection_id),
2490        guests.iter().chain([host]).copied(),
2491        |connection_id| {
2492            session
2493                .peer
2494                .forward_send(session.connection_id, connection_id, message.clone())
2495        },
2496    );
2497
2498    Ok(())
2499}
2500
2501/// Notify other participants that a project has been updated.
2502async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2503    request: T,
2504    session: MessageContext,
2505) -> Result<()> {
2506    let project_id = ProjectId::from_proto(request.remote_entity_id());
2507    let project_connection_ids = session
2508        .db()
2509        .await
2510        .project_connection_ids(project_id, session.connection_id, false)
2511        .await?;
2512
2513    broadcast(
2514        Some(session.connection_id),
2515        project_connection_ids.iter().copied(),
2516        |connection_id| {
2517            session
2518                .peer
2519                .forward_send(session.connection_id, connection_id, request.clone())
2520        },
2521    );
2522    Ok(())
2523}
2524
2525/// Start following another user in a call.
2526async fn follow(
2527    request: proto::Follow,
2528    response: Response<proto::Follow>,
2529    session: MessageContext,
2530) -> Result<()> {
2531    let room_id = RoomId::from_proto(request.room_id);
2532    let project_id = request.project_id.map(ProjectId::from_proto);
2533    let leader_id = request.leader_id.context("invalid leader id")?.into();
2534    let follower_id = session.connection_id;
2535
2536    session
2537        .db()
2538        .await
2539        .check_room_participants(room_id, leader_id, session.connection_id)
2540        .await?;
2541
2542    let response_payload = session.forward_request(leader_id, request).await?;
2543    response.send(response_payload)?;
2544
2545    if let Some(project_id) = project_id {
2546        let room = session
2547            .db()
2548            .await
2549            .follow(room_id, project_id, leader_id, follower_id)
2550            .await?;
2551        room_updated(&room, &session.peer);
2552    }
2553
2554    Ok(())
2555}
2556
2557/// Stop following another user in a call.
2558async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2559    let room_id = RoomId::from_proto(request.room_id);
2560    let project_id = request.project_id.map(ProjectId::from_proto);
2561    let leader_id = request.leader_id.context("invalid leader id")?.into();
2562    let follower_id = session.connection_id;
2563
2564    session
2565        .db()
2566        .await
2567        .check_room_participants(room_id, leader_id, session.connection_id)
2568        .await?;
2569
2570    session
2571        .peer
2572        .forward_send(session.connection_id, leader_id, request)?;
2573
2574    if let Some(project_id) = project_id {
2575        let room = session
2576            .db()
2577            .await
2578            .unfollow(room_id, project_id, leader_id, follower_id)
2579            .await?;
2580        room_updated(&room, &session.peer);
2581    }
2582
2583    Ok(())
2584}
2585
2586/// Notify everyone following you of your current location.
2587async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2588    let room_id = RoomId::from_proto(request.room_id);
2589    let database = session.db.lock().await;
2590
2591    let connection_ids = if let Some(project_id) = request.project_id {
2592        let project_id = ProjectId::from_proto(project_id);
2593        database
2594            .project_connection_ids(project_id, session.connection_id, true)
2595            .await?
2596    } else {
2597        database
2598            .room_connection_ids(room_id, session.connection_id)
2599            .await?
2600    };
2601
2602    // For now, don't send view update messages back to that view's current leader.
2603    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2604        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2605        _ => None,
2606    });
2607
2608    for connection_id in connection_ids.iter().cloned() {
2609        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2610            session
2611                .peer
2612                .forward_send(session.connection_id, connection_id, request.clone())?;
2613        }
2614    }
2615    Ok(())
2616}
2617
2618/// Get public data about users.
2619async fn get_users(
2620    request: proto::GetUsers,
2621    response: Response<proto::GetUsers>,
2622    session: MessageContext,
2623) -> Result<()> {
2624    let user_ids = request
2625        .user_ids
2626        .into_iter()
2627        .map(UserId::from_proto)
2628        .collect();
2629    let users = session
2630        .db()
2631        .await
2632        .get_users_by_ids(user_ids)
2633        .await?
2634        .into_iter()
2635        .map(|user| proto::User {
2636            id: user.id.to_proto(),
2637            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2638            github_login: user.github_login,
2639            name: user.name,
2640        })
2641        .collect();
2642    response.send(proto::UsersResponse { users })?;
2643    Ok(())
2644}
2645
2646/// Search for users (to invite) buy Github login
2647async fn fuzzy_search_users(
2648    request: proto::FuzzySearchUsers,
2649    response: Response<proto::FuzzySearchUsers>,
2650    session: MessageContext,
2651) -> Result<()> {
2652    let query = request.query;
2653    let users = match query.len() {
2654        0 => vec![],
2655        1 | 2 => session
2656            .db()
2657            .await
2658            .get_user_by_github_login(&query)
2659            .await?
2660            .into_iter()
2661            .collect(),
2662        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2663    };
2664    let users = users
2665        .into_iter()
2666        .filter(|user| user.id != session.user_id())
2667        .map(|user| proto::User {
2668            id: user.id.to_proto(),
2669            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2670            github_login: user.github_login,
2671            name: user.name,
2672        })
2673        .collect();
2674    response.send(proto::UsersResponse { users })?;
2675    Ok(())
2676}
2677
2678/// Send a contact request to another user.
2679async fn request_contact(
2680    request: proto::RequestContact,
2681    response: Response<proto::RequestContact>,
2682    session: MessageContext,
2683) -> Result<()> {
2684    let requester_id = session.user_id();
2685    let responder_id = UserId::from_proto(request.responder_id);
2686    if requester_id == responder_id {
2687        return Err(anyhow!("cannot add yourself as a contact"))?;
2688    }
2689
2690    let notifications = session
2691        .db()
2692        .await
2693        .send_contact_request(requester_id, responder_id)
2694        .await?;
2695
2696    // Update outgoing contact requests of requester
2697    let mut update = proto::UpdateContacts::default();
2698    update.outgoing_requests.push(responder_id.to_proto());
2699    for connection_id in session
2700        .connection_pool()
2701        .await
2702        .user_connection_ids(requester_id)
2703    {
2704        session.peer.send(connection_id, update.clone())?;
2705    }
2706
2707    // Update incoming contact requests of responder
2708    let mut update = proto::UpdateContacts::default();
2709    update
2710        .incoming_requests
2711        .push(proto::IncomingContactRequest {
2712            requester_id: requester_id.to_proto(),
2713        });
2714    let connection_pool = session.connection_pool().await;
2715    for connection_id in connection_pool.user_connection_ids(responder_id) {
2716        session.peer.send(connection_id, update.clone())?;
2717    }
2718
2719    send_notifications(&connection_pool, &session.peer, notifications);
2720
2721    response.send(proto::Ack {})?;
2722    Ok(())
2723}
2724
2725/// Accept or decline a contact request
2726async fn respond_to_contact_request(
2727    request: proto::RespondToContactRequest,
2728    response: Response<proto::RespondToContactRequest>,
2729    session: MessageContext,
2730) -> Result<()> {
2731    let responder_id = session.user_id();
2732    let requester_id = UserId::from_proto(request.requester_id);
2733    let db = session.db().await;
2734    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2735        db.dismiss_contact_notification(responder_id, requester_id)
2736            .await?;
2737    } else {
2738        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2739
2740        let notifications = db
2741            .respond_to_contact_request(responder_id, requester_id, accept)
2742            .await?;
2743        let requester_busy = db.is_user_busy(requester_id).await?;
2744        let responder_busy = db.is_user_busy(responder_id).await?;
2745
2746        let pool = session.connection_pool().await;
2747        // Update responder with new contact
2748        let mut update = proto::UpdateContacts::default();
2749        if accept {
2750            update
2751                .contacts
2752                .push(contact_for_user(requester_id, requester_busy, &pool));
2753        }
2754        update
2755            .remove_incoming_requests
2756            .push(requester_id.to_proto());
2757        for connection_id in pool.user_connection_ids(responder_id) {
2758            session.peer.send(connection_id, update.clone())?;
2759        }
2760
2761        // Update requester with new contact
2762        let mut update = proto::UpdateContacts::default();
2763        if accept {
2764            update
2765                .contacts
2766                .push(contact_for_user(responder_id, responder_busy, &pool));
2767        }
2768        update
2769            .remove_outgoing_requests
2770            .push(responder_id.to_proto());
2771
2772        for connection_id in pool.user_connection_ids(requester_id) {
2773            session.peer.send(connection_id, update.clone())?;
2774        }
2775
2776        send_notifications(&pool, &session.peer, notifications);
2777    }
2778
2779    response.send(proto::Ack {})?;
2780    Ok(())
2781}
2782
2783/// Remove a contact.
2784async fn remove_contact(
2785    request: proto::RemoveContact,
2786    response: Response<proto::RemoveContact>,
2787    session: MessageContext,
2788) -> Result<()> {
2789    let requester_id = session.user_id();
2790    let responder_id = UserId::from_proto(request.user_id);
2791    let db = session.db().await;
2792    let (contact_accepted, deleted_notification_id) =
2793        db.remove_contact(requester_id, responder_id).await?;
2794
2795    let pool = session.connection_pool().await;
2796    // Update outgoing contact requests of requester
2797    let mut update = proto::UpdateContacts::default();
2798    if contact_accepted {
2799        update.remove_contacts.push(responder_id.to_proto());
2800    } else {
2801        update
2802            .remove_outgoing_requests
2803            .push(responder_id.to_proto());
2804    }
2805    for connection_id in pool.user_connection_ids(requester_id) {
2806        session.peer.send(connection_id, update.clone())?;
2807    }
2808
2809    // Update incoming contact requests of responder
2810    let mut update = proto::UpdateContacts::default();
2811    if contact_accepted {
2812        update.remove_contacts.push(requester_id.to_proto());
2813    } else {
2814        update
2815            .remove_incoming_requests
2816            .push(requester_id.to_proto());
2817    }
2818    for connection_id in pool.user_connection_ids(responder_id) {
2819        session.peer.send(connection_id, update.clone())?;
2820        if let Some(notification_id) = deleted_notification_id {
2821            session.peer.send(
2822                connection_id,
2823                proto::DeleteNotification {
2824                    notification_id: notification_id.to_proto(),
2825                },
2826            )?;
2827        }
2828    }
2829
2830    response.send(proto::Ack {})?;
2831    Ok(())
2832}
2833
2834fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2835    version.0.minor() < 139
2836}
2837
2838async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2839    if is_staff {
2840        return Ok(proto::Plan::ZedPro);
2841    }
2842
2843    let subscription = db.get_active_billing_subscription(user_id).await?;
2844    let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2845
2846    let plan = if let Some(subscription_kind) = subscription_kind {
2847        match subscription_kind {
2848            SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2849            SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2850            SubscriptionKind::ZedFree => proto::Plan::Free,
2851        }
2852    } else {
2853        proto::Plan::Free
2854    };
2855
2856    Ok(plan)
2857}
2858
2859async fn make_update_user_plan_message(
2860    user: &User,
2861    is_staff: bool,
2862    db: &Arc<Database>,
2863    llm_db: Option<Arc<LlmDatabase>>,
2864) -> Result<proto::UpdateUserPlan> {
2865    let feature_flags = db.get_user_flags(user.id).await?;
2866    let plan = current_plan(db, user.id, is_staff).await?;
2867    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
2868    let billing_preferences = db.get_billing_preferences(user.id).await?;
2869
2870    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
2871        let subscription = db.get_active_billing_subscription(user.id).await?;
2872
2873        let subscription_period =
2874            crate::db::billing_subscription::Model::current_period(subscription, is_staff);
2875
2876        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
2877            llm_db
2878                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
2879                .await?
2880        } else {
2881            None
2882        };
2883
2884        (subscription_period, usage)
2885    } else {
2886        (None, None)
2887    };
2888
2889    let bypass_account_age_check = feature_flags
2890        .iter()
2891        .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
2892    let account_too_young = !matches!(plan, proto::Plan::ZedPro)
2893        && !bypass_account_age_check
2894        && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;
2895
2896    Ok(proto::UpdateUserPlan {
2897        plan: plan.into(),
2898        trial_started_at: billing_customer
2899            .as_ref()
2900            .and_then(|billing_customer| billing_customer.trial_started_at)
2901            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2902        is_usage_based_billing_enabled: if is_staff {
2903            Some(true)
2904        } else {
2905            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
2906        },
2907        subscription_period: subscription_period.map(|(started_at, ended_at)| {
2908            proto::SubscriptionPeriod {
2909                started_at: started_at.timestamp() as u64,
2910                ended_at: ended_at.timestamp() as u64,
2911            }
2912        }),
2913        account_too_young: Some(account_too_young),
2914        has_overdue_invoices: billing_customer
2915            .map(|billing_customer| billing_customer.has_overdue_invoices),
2916        usage: Some(
2917            usage
2918                .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
2919                .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
2920        ),
2921    })
2922}
2923
2924fn model_requests_limit(
2925    plan: cloud_llm_client::Plan,
2926    feature_flags: &Vec<String>,
2927) -> cloud_llm_client::UsageLimit {
2928    match plan.model_requests_limit() {
2929        cloud_llm_client::UsageLimit::Limited(limit) => {
2930            let limit = if plan == cloud_llm_client::Plan::ZedProTrial
2931                && feature_flags
2932                    .iter()
2933                    .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2934            {
2935                1_000
2936            } else {
2937                limit
2938            };
2939
2940            cloud_llm_client::UsageLimit::Limited(limit)
2941        }
2942        cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited,
2943    }
2944}
2945
2946fn subscription_usage_to_proto(
2947    plan: proto::Plan,
2948    usage: crate::llm::db::subscription_usage::Model,
2949    feature_flags: &Vec<String>,
2950) -> proto::SubscriptionUsage {
2951    let plan = match plan {
2952        proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2953        proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
2954        proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
2955    };
2956
2957    proto::SubscriptionUsage {
2958        model_requests_usage_amount: usage.model_requests as u32,
2959        model_requests_usage_limit: Some(proto::UsageLimit {
2960            variant: Some(match model_requests_limit(plan, feature_flags) {
2961                cloud_llm_client::UsageLimit::Limited(limit) => {
2962                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2963                        limit: limit as u32,
2964                    })
2965                }
2966                cloud_llm_client::UsageLimit::Unlimited => {
2967                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2968                }
2969            }),
2970        }),
2971        edit_predictions_usage_amount: usage.edit_predictions as u32,
2972        edit_predictions_usage_limit: Some(proto::UsageLimit {
2973            variant: Some(match plan.edit_predictions_limit() {
2974                cloud_llm_client::UsageLimit::Limited(limit) => {
2975                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2976                        limit: limit as u32,
2977                    })
2978                }
2979                cloud_llm_client::UsageLimit::Unlimited => {
2980                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2981                }
2982            }),
2983        }),
2984    }
2985}
2986
2987fn make_default_subscription_usage(
2988    plan: proto::Plan,
2989    feature_flags: &Vec<String>,
2990) -> proto::SubscriptionUsage {
2991    let plan = match plan {
2992        proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2993        proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
2994        proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
2995    };
2996
2997    proto::SubscriptionUsage {
2998        model_requests_usage_amount: 0,
2999        model_requests_usage_limit: Some(proto::UsageLimit {
3000            variant: Some(match model_requests_limit(plan, feature_flags) {
3001                cloud_llm_client::UsageLimit::Limited(limit) => {
3002                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3003                        limit: limit as u32,
3004                    })
3005                }
3006                cloud_llm_client::UsageLimit::Unlimited => {
3007                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3008                }
3009            }),
3010        }),
3011        edit_predictions_usage_amount: 0,
3012        edit_predictions_usage_limit: Some(proto::UsageLimit {
3013            variant: Some(match plan.edit_predictions_limit() {
3014                cloud_llm_client::UsageLimit::Limited(limit) => {
3015                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3016                        limit: limit as u32,
3017                    })
3018                }
3019                cloud_llm_client::UsageLimit::Unlimited => {
3020                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3021                }
3022            }),
3023        }),
3024    }
3025}
3026
3027async fn update_user_plan(session: &Session) -> Result<()> {
3028    let db = session.db().await;
3029
3030    let update_user_plan = make_update_user_plan_message(
3031        session.principal.user(),
3032        session.is_staff(),
3033        &db.0,
3034        session.app_state.llm_db.clone(),
3035    )
3036    .await?;
3037
3038    session
3039        .peer
3040        .send(session.connection_id, update_user_plan)
3041        .trace_err();
3042
3043    Ok(())
3044}
3045
3046async fn subscribe_to_channels(
3047    _: proto::SubscribeToChannels,
3048    session: MessageContext,
3049) -> Result<()> {
3050    subscribe_user_to_channels(session.user_id(), &session).await?;
3051    Ok(())
3052}
3053
3054async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3055    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3056    let mut pool = session.connection_pool().await;
3057    for membership in &channels_for_user.channel_memberships {
3058        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3059    }
3060    session.peer.send(
3061        session.connection_id,
3062        build_update_user_channels(&channels_for_user),
3063    )?;
3064    session.peer.send(
3065        session.connection_id,
3066        build_channels_update(channels_for_user),
3067    )?;
3068    Ok(())
3069}
3070
3071/// Creates a new channel.
3072async fn create_channel(
3073    request: proto::CreateChannel,
3074    response: Response<proto::CreateChannel>,
3075    session: MessageContext,
3076) -> Result<()> {
3077    let db = session.db().await;
3078
3079    let parent_id = request.parent_id.map(ChannelId::from_proto);
3080    let (channel, membership) = db
3081        .create_channel(&request.name, parent_id, session.user_id())
3082        .await?;
3083
3084    let root_id = channel.root_id();
3085    let channel = Channel::from_model(channel);
3086
3087    response.send(proto::CreateChannelResponse {
3088        channel: Some(channel.to_proto()),
3089        parent_id: request.parent_id,
3090    })?;
3091
3092    let mut connection_pool = session.connection_pool().await;
3093    if let Some(membership) = membership {
3094        connection_pool.subscribe_to_channel(
3095            membership.user_id,
3096            membership.channel_id,
3097            membership.role,
3098        );
3099        let update = proto::UpdateUserChannels {
3100            channel_memberships: vec![proto::ChannelMembership {
3101                channel_id: membership.channel_id.to_proto(),
3102                role: membership.role.into(),
3103            }],
3104            ..Default::default()
3105        };
3106        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3107            session.peer.send(connection_id, update.clone())?;
3108        }
3109    }
3110
3111    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3112        if !role.can_see_channel(channel.visibility) {
3113            continue;
3114        }
3115
3116        let update = proto::UpdateChannels {
3117            channels: vec![channel.to_proto()],
3118            ..Default::default()
3119        };
3120        session.peer.send(connection_id, update.clone())?;
3121    }
3122
3123    Ok(())
3124}
3125
3126/// Delete a channel
3127async fn delete_channel(
3128    request: proto::DeleteChannel,
3129    response: Response<proto::DeleteChannel>,
3130    session: MessageContext,
3131) -> Result<()> {
3132    let db = session.db().await;
3133
3134    let channel_id = request.channel_id;
3135    let (root_channel, removed_channels) = db
3136        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3137        .await?;
3138    response.send(proto::Ack {})?;
3139
3140    // Notify members of removed channels
3141    let mut update = proto::UpdateChannels::default();
3142    update
3143        .delete_channels
3144        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3145
3146    let connection_pool = session.connection_pool().await;
3147    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3148        session.peer.send(connection_id, update.clone())?;
3149    }
3150
3151    Ok(())
3152}
3153
3154/// Invite someone to join a channel.
3155async fn invite_channel_member(
3156    request: proto::InviteChannelMember,
3157    response: Response<proto::InviteChannelMember>,
3158    session: MessageContext,
3159) -> Result<()> {
3160    let db = session.db().await;
3161    let channel_id = ChannelId::from_proto(request.channel_id);
3162    let invitee_id = UserId::from_proto(request.user_id);
3163    let InviteMemberResult {
3164        channel,
3165        notifications,
3166    } = db
3167        .invite_channel_member(
3168            channel_id,
3169            invitee_id,
3170            session.user_id(),
3171            request.role().into(),
3172        )
3173        .await?;
3174
3175    let update = proto::UpdateChannels {
3176        channel_invitations: vec![channel.to_proto()],
3177        ..Default::default()
3178    };
3179
3180    let connection_pool = session.connection_pool().await;
3181    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3182        session.peer.send(connection_id, update.clone())?;
3183    }
3184
3185    send_notifications(&connection_pool, &session.peer, notifications);
3186
3187    response.send(proto::Ack {})?;
3188    Ok(())
3189}
3190
3191/// remove someone from a channel
3192async fn remove_channel_member(
3193    request: proto::RemoveChannelMember,
3194    response: Response<proto::RemoveChannelMember>,
3195    session: MessageContext,
3196) -> Result<()> {
3197    let db = session.db().await;
3198    let channel_id = ChannelId::from_proto(request.channel_id);
3199    let member_id = UserId::from_proto(request.user_id);
3200
3201    let RemoveChannelMemberResult {
3202        membership_update,
3203        notification_id,
3204    } = db
3205        .remove_channel_member(channel_id, member_id, session.user_id())
3206        .await?;
3207
3208    let mut connection_pool = session.connection_pool().await;
3209    notify_membership_updated(
3210        &mut connection_pool,
3211        membership_update,
3212        member_id,
3213        &session.peer,
3214    );
3215    for connection_id in connection_pool.user_connection_ids(member_id) {
3216        if let Some(notification_id) = notification_id {
3217            session
3218                .peer
3219                .send(
3220                    connection_id,
3221                    proto::DeleteNotification {
3222                        notification_id: notification_id.to_proto(),
3223                    },
3224                )
3225                .trace_err();
3226        }
3227    }
3228
3229    response.send(proto::Ack {})?;
3230    Ok(())
3231}
3232
3233/// Toggle the channel between public and private.
3234/// Care is taken to maintain the invariant that public channels only descend from public channels,
3235/// (though members-only channels can appear at any point in the hierarchy).
3236async fn set_channel_visibility(
3237    request: proto::SetChannelVisibility,
3238    response: Response<proto::SetChannelVisibility>,
3239    session: MessageContext,
3240) -> Result<()> {
3241    let db = session.db().await;
3242    let channel_id = ChannelId::from_proto(request.channel_id);
3243    let visibility = request.visibility().into();
3244
3245    let channel_model = db
3246        .set_channel_visibility(channel_id, visibility, session.user_id())
3247        .await?;
3248    let root_id = channel_model.root_id();
3249    let channel = Channel::from_model(channel_model);
3250
3251    let mut connection_pool = session.connection_pool().await;
3252    for (user_id, role) in connection_pool
3253        .channel_user_ids(root_id)
3254        .collect::<Vec<_>>()
3255        .into_iter()
3256    {
3257        let update = if role.can_see_channel(channel.visibility) {
3258            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3259            proto::UpdateChannels {
3260                channels: vec![channel.to_proto()],
3261                ..Default::default()
3262            }
3263        } else {
3264            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3265            proto::UpdateChannels {
3266                delete_channels: vec![channel.id.to_proto()],
3267                ..Default::default()
3268            }
3269        };
3270
3271        for connection_id in connection_pool.user_connection_ids(user_id) {
3272            session.peer.send(connection_id, update.clone())?;
3273        }
3274    }
3275
3276    response.send(proto::Ack {})?;
3277    Ok(())
3278}
3279
3280/// Alter the role for a user in the channel.
3281async fn set_channel_member_role(
3282    request: proto::SetChannelMemberRole,
3283    response: Response<proto::SetChannelMemberRole>,
3284    session: MessageContext,
3285) -> Result<()> {
3286    let db = session.db().await;
3287    let channel_id = ChannelId::from_proto(request.channel_id);
3288    let member_id = UserId::from_proto(request.user_id);
3289    let result = db
3290        .set_channel_member_role(
3291            channel_id,
3292            session.user_id(),
3293            member_id,
3294            request.role().into(),
3295        )
3296        .await?;
3297
3298    match result {
3299        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3300            let mut connection_pool = session.connection_pool().await;
3301            notify_membership_updated(
3302                &mut connection_pool,
3303                membership_update,
3304                member_id,
3305                &session.peer,
3306            )
3307        }
3308        db::SetMemberRoleResult::InviteUpdated(channel) => {
3309            let update = proto::UpdateChannels {
3310                channel_invitations: vec![channel.to_proto()],
3311                ..Default::default()
3312            };
3313
3314            for connection_id in session
3315                .connection_pool()
3316                .await
3317                .user_connection_ids(member_id)
3318            {
3319                session.peer.send(connection_id, update.clone())?;
3320            }
3321        }
3322    }
3323
3324    response.send(proto::Ack {})?;
3325    Ok(())
3326}
3327
3328/// Change the name of a channel
3329async fn rename_channel(
3330    request: proto::RenameChannel,
3331    response: Response<proto::RenameChannel>,
3332    session: MessageContext,
3333) -> Result<()> {
3334    let db = session.db().await;
3335    let channel_id = ChannelId::from_proto(request.channel_id);
3336    let channel_model = db
3337        .rename_channel(channel_id, session.user_id(), &request.name)
3338        .await?;
3339    let root_id = channel_model.root_id();
3340    let channel = Channel::from_model(channel_model);
3341
3342    response.send(proto::RenameChannelResponse {
3343        channel: Some(channel.to_proto()),
3344    })?;
3345
3346    let connection_pool = session.connection_pool().await;
3347    let update = proto::UpdateChannels {
3348        channels: vec![channel.to_proto()],
3349        ..Default::default()
3350    };
3351    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3352        if role.can_see_channel(channel.visibility) {
3353            session.peer.send(connection_id, update.clone())?;
3354        }
3355    }
3356
3357    Ok(())
3358}
3359
3360/// Move a channel to a new parent.
3361async fn move_channel(
3362    request: proto::MoveChannel,
3363    response: Response<proto::MoveChannel>,
3364    session: MessageContext,
3365) -> Result<()> {
3366    let channel_id = ChannelId::from_proto(request.channel_id);
3367    let to = ChannelId::from_proto(request.to);
3368
3369    let (root_id, channels) = session
3370        .db()
3371        .await
3372        .move_channel(channel_id, to, session.user_id())
3373        .await?;
3374
3375    let connection_pool = session.connection_pool().await;
3376    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3377        let channels = channels
3378            .iter()
3379            .filter_map(|channel| {
3380                if role.can_see_channel(channel.visibility) {
3381                    Some(channel.to_proto())
3382                } else {
3383                    None
3384                }
3385            })
3386            .collect::<Vec<_>>();
3387        if channels.is_empty() {
3388            continue;
3389        }
3390
3391        let update = proto::UpdateChannels {
3392            channels,
3393            ..Default::default()
3394        };
3395
3396        session.peer.send(connection_id, update.clone())?;
3397    }
3398
3399    response.send(Ack {})?;
3400    Ok(())
3401}
3402
3403async fn reorder_channel(
3404    request: proto::ReorderChannel,
3405    response: Response<proto::ReorderChannel>,
3406    session: MessageContext,
3407) -> Result<()> {
3408    let channel_id = ChannelId::from_proto(request.channel_id);
3409    let direction = request.direction();
3410
3411    let updated_channels = session
3412        .db()
3413        .await
3414        .reorder_channel(channel_id, direction, session.user_id())
3415        .await?;
3416
3417    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3418        let connection_pool = session.connection_pool().await;
3419        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3420            let channels = updated_channels
3421                .iter()
3422                .filter_map(|channel| {
3423                    if role.can_see_channel(channel.visibility) {
3424                        Some(channel.to_proto())
3425                    } else {
3426                        None
3427                    }
3428                })
3429                .collect::<Vec<_>>();
3430
3431            if channels.is_empty() {
3432                continue;
3433            }
3434
3435            let update = proto::UpdateChannels {
3436                channels,
3437                ..Default::default()
3438            };
3439
3440            session.peer.send(connection_id, update.clone())?;
3441        }
3442    }
3443
3444    response.send(Ack {})?;
3445    Ok(())
3446}
3447
3448/// Get the list of channel members
3449async fn get_channel_members(
3450    request: proto::GetChannelMembers,
3451    response: Response<proto::GetChannelMembers>,
3452    session: MessageContext,
3453) -> Result<()> {
3454    let db = session.db().await;
3455    let channel_id = ChannelId::from_proto(request.channel_id);
3456    let limit = if request.limit == 0 {
3457        u16::MAX as u64
3458    } else {
3459        request.limit
3460    };
3461    let (members, users) = db
3462        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3463        .await?;
3464    response.send(proto::GetChannelMembersResponse { members, users })?;
3465    Ok(())
3466}
3467
3468/// Accept or decline a channel invitation.
3469async fn respond_to_channel_invite(
3470    request: proto::RespondToChannelInvite,
3471    response: Response<proto::RespondToChannelInvite>,
3472    session: MessageContext,
3473) -> Result<()> {
3474    let db = session.db().await;
3475    let channel_id = ChannelId::from_proto(request.channel_id);
3476    let RespondToChannelInvite {
3477        membership_update,
3478        notifications,
3479    } = db
3480        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3481        .await?;
3482
3483    let mut connection_pool = session.connection_pool().await;
3484    if let Some(membership_update) = membership_update {
3485        notify_membership_updated(
3486            &mut connection_pool,
3487            membership_update,
3488            session.user_id(),
3489            &session.peer,
3490        );
3491    } else {
3492        let update = proto::UpdateChannels {
3493            remove_channel_invitations: vec![channel_id.to_proto()],
3494            ..Default::default()
3495        };
3496
3497        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3498            session.peer.send(connection_id, update.clone())?;
3499        }
3500    };
3501
3502    send_notifications(&connection_pool, &session.peer, notifications);
3503
3504    response.send(proto::Ack {})?;
3505
3506    Ok(())
3507}
3508
3509/// Join the channels' room
3510async fn join_channel(
3511    request: proto::JoinChannel,
3512    response: Response<proto::JoinChannel>,
3513    session: MessageContext,
3514) -> Result<()> {
3515    let channel_id = ChannelId::from_proto(request.channel_id);
3516    join_channel_internal(channel_id, Box::new(response), session).await
3517}
3518
/// Response handle accepted by `join_channel_internal`.
///
/// Both `Response<proto::JoinChannel>` and `Response<proto::JoinRoom>` implement it below, so
/// the same join logic can answer either request with a `proto::JoinRoomResponse`.
trait JoinChannelInternalResponse {
    /// Send `result` to the requester, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Answers a `JoinChannel` request by forwarding to `Response::<proto::JoinChannel>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Answers a `JoinRoom` request by forwarding to `Response::<proto::JoinRoom>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3532
/// Shared implementation for joining a channel's room.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, that connection is
///    removed via `leave_room_for_session` before joining.
/// 2. `Database::join_channel` joins the room. It returns the joined room, an optional
///    membership update, and the user's role in the channel.
/// 3. If a LiveKit client is configured, a token is minted. Guests get a guest token and
///    `can_publish: false`; every other role gets a room token and `can_publish: true`. A
///    token error is logged via `trace_err` and yields no connection info rather than
///    failing the join.
/// 4. The `JoinRoomResponse` is sent, any membership update is passed to
///    `notify_membership_updated`, and `room_updated` is invoked for the joined room.
/// 5. `channel_updated` is called for the joined channel and `update_user_contacts` runs for
///    the user.
///
/// `response` is generic over `JoinChannelInternalResponse`, so this can answer either a
/// `JoinChannel` or a `JoinRoom` request.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    // Scoped so the database and connection-pool guards taken inside are dropped before
    // `channel_updated` locks the connection pool again below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the session's database handle around `leave_room_for_session`, then
            // re-acquire it. Presumably that call takes the handle itself — TODO confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or when minting the token failed
        // (the error is logged by `trace_err`).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests may join but not publish media.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // NOTE(review): the join response was already sent above, so a missing channel surfaces
    // here as a handler error after the client was told the join succeeded. Confirm that
    // `join_channel` always returns a channel.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3626
3627/// Start editing the channel notes
3628async fn join_channel_buffer(
3629    request: proto::JoinChannelBuffer,
3630    response: Response<proto::JoinChannelBuffer>,
3631    session: MessageContext,
3632) -> Result<()> {
3633    let db = session.db().await;
3634    let channel_id = ChannelId::from_proto(request.channel_id);
3635
3636    let open_response = db
3637        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3638        .await?;
3639
3640    let collaborators = open_response.collaborators.clone();
3641    response.send(open_response)?;
3642
3643    let update = UpdateChannelBufferCollaborators {
3644        channel_id: channel_id.to_proto(),
3645        collaborators: collaborators.clone(),
3646    };
3647    channel_buffer_updated(
3648        session.connection_id,
3649        collaborators
3650            .iter()
3651            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3652        &update,
3653        &session.peer,
3654    );
3655
3656    Ok(())
3657}
3658
3659/// Edit the channel notes
3660async fn update_channel_buffer(
3661    request: proto::UpdateChannelBuffer,
3662    session: MessageContext,
3663) -> Result<()> {
3664    let db = session.db().await;
3665    let channel_id = ChannelId::from_proto(request.channel_id);
3666
3667    let (collaborators, epoch, version) = db
3668        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3669        .await?;
3670
3671    channel_buffer_updated(
3672        session.connection_id,
3673        collaborators.clone(),
3674        &proto::UpdateChannelBuffer {
3675            channel_id: channel_id.to_proto(),
3676            operations: request.operations,
3677        },
3678        &session.peer,
3679    );
3680
3681    let pool = &*session.connection_pool().await;
3682
3683    let non_collaborators =
3684        pool.channel_connection_ids(channel_id)
3685            .filter_map(|(connection_id, _)| {
3686                if collaborators.contains(&connection_id) {
3687                    None
3688                } else {
3689                    Some(connection_id)
3690                }
3691            });
3692
3693    broadcast(None, non_collaborators, |peer_id| {
3694        session.peer.send(
3695            peer_id,
3696            proto::UpdateChannels {
3697                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3698                    channel_id: channel_id.to_proto(),
3699                    epoch: epoch as u64,
3700                    version: version.clone(),
3701                }],
3702                ..Default::default()
3703            },
3704        )
3705    });
3706
3707    Ok(())
3708}
3709
3710/// Rejoin the channel notes after a connection blip
3711async fn rejoin_channel_buffers(
3712    request: proto::RejoinChannelBuffers,
3713    response: Response<proto::RejoinChannelBuffers>,
3714    session: MessageContext,
3715) -> Result<()> {
3716    let db = session.db().await;
3717    let buffers = db
3718        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3719        .await?;
3720
3721    for rejoined_buffer in &buffers {
3722        let collaborators_to_notify = rejoined_buffer
3723            .buffer
3724            .collaborators
3725            .iter()
3726            .filter_map(|c| Some(c.peer_id?.into()));
3727        channel_buffer_updated(
3728            session.connection_id,
3729            collaborators_to_notify,
3730            &proto::UpdateChannelBufferCollaborators {
3731                channel_id: rejoined_buffer.buffer.channel_id,
3732                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3733            },
3734            &session.peer,
3735        );
3736    }
3737
3738    response.send(proto::RejoinChannelBuffersResponse {
3739        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3740    })?;
3741
3742    Ok(())
3743}
3744
3745/// Stop editing the channel notes
3746async fn leave_channel_buffer(
3747    request: proto::LeaveChannelBuffer,
3748    response: Response<proto::LeaveChannelBuffer>,
3749    session: MessageContext,
3750) -> Result<()> {
3751    let db = session.db().await;
3752    let channel_id = ChannelId::from_proto(request.channel_id);
3753
3754    let left_buffer = db
3755        .leave_channel_buffer(channel_id, session.connection_id)
3756        .await?;
3757
3758    response.send(Ack {})?;
3759
3760    channel_buffer_updated(
3761        session.connection_id,
3762        left_buffer.connections,
3763        &proto::UpdateChannelBufferCollaborators {
3764            channel_id: channel_id.to_proto(),
3765            collaborators: left_buffer.collaborators,
3766        },
3767        &session.peer,
3768    );
3769
3770    Ok(())
3771}
3772
3773fn channel_buffer_updated<T: EnvelopedMessage>(
3774    sender_id: ConnectionId,
3775    collaborators: impl IntoIterator<Item = ConnectionId>,
3776    message: &T,
3777    peer: &Peer,
3778) {
3779    broadcast(Some(sender_id), collaborators, |peer_id| {
3780        peer.send(peer_id, message.clone())
3781    });
3782}
3783
3784fn send_notifications(
3785    connection_pool: &ConnectionPool,
3786    peer: &Peer,
3787    notifications: db::NotificationBatch,
3788) {
3789    for (user_id, notification) in notifications {
3790        for connection_id in connection_pool.user_connection_ids(user_id) {
3791            if let Err(error) = peer.send(
3792                connection_id,
3793                proto::AddNotification {
3794                    notification: Some(notification.clone()),
3795                },
3796            ) {
3797                tracing::error!(
3798                    "failed to send notification to {:?} {}",
3799                    connection_id,
3800                    error
3801                );
3802            }
3803        }
3804    }
3805}
3806
3807/// Send a message to the channel
3808async fn send_channel_message(
3809    request: proto::SendChannelMessage,
3810    response: Response<proto::SendChannelMessage>,
3811    session: MessageContext,
3812) -> Result<()> {
3813    // Validate the message body.
3814    let body = request.body.trim().to_string();
3815    if body.len() > MAX_MESSAGE_LEN {
3816        return Err(anyhow!("message is too long"))?;
3817    }
3818    if body.is_empty() {
3819        return Err(anyhow!("message can't be blank"))?;
3820    }
3821
3822    // TODO: adjust mentions if body is trimmed
3823
3824    let timestamp = OffsetDateTime::now_utc();
3825    let nonce = request.nonce.context("nonce can't be blank")?;
3826
3827    let channel_id = ChannelId::from_proto(request.channel_id);
3828    let CreatedChannelMessage {
3829        message_id,
3830        participant_connection_ids,
3831        notifications,
3832    } = session
3833        .db()
3834        .await
3835        .create_channel_message(
3836            channel_id,
3837            session.user_id(),
3838            &body,
3839            &request.mentions,
3840            timestamp,
3841            nonce.clone().into(),
3842            request.reply_to_message_id.map(MessageId::from_proto),
3843        )
3844        .await?;
3845
3846    let message = proto::ChannelMessage {
3847        sender_id: session.user_id().to_proto(),
3848        id: message_id.to_proto(),
3849        body,
3850        mentions: request.mentions,
3851        timestamp: timestamp.unix_timestamp() as u64,
3852        nonce: Some(nonce),
3853        reply_to_message_id: request.reply_to_message_id,
3854        edited_at: None,
3855    };
3856    broadcast(
3857        Some(session.connection_id),
3858        participant_connection_ids.clone(),
3859        |connection| {
3860            session.peer.send(
3861                connection,
3862                proto::ChannelMessageSent {
3863                    channel_id: channel_id.to_proto(),
3864                    message: Some(message.clone()),
3865                },
3866            )
3867        },
3868    );
3869    response.send(proto::SendChannelMessageResponse {
3870        message: Some(message),
3871    })?;
3872
3873    let pool = &*session.connection_pool().await;
3874    let non_participants =
3875        pool.channel_connection_ids(channel_id)
3876            .filter_map(|(connection_id, _)| {
3877                if participant_connection_ids.contains(&connection_id) {
3878                    None
3879                } else {
3880                    Some(connection_id)
3881                }
3882            });
3883    broadcast(None, non_participants, |peer_id| {
3884        session.peer.send(
3885            peer_id,
3886            proto::UpdateChannels {
3887                latest_channel_message_ids: vec![proto::ChannelMessageId {
3888                    channel_id: channel_id.to_proto(),
3889                    message_id: message_id.to_proto(),
3890                }],
3891                ..Default::default()
3892            },
3893        )
3894    });
3895    send_notifications(pool, &session.peer, notifications);
3896
3897    Ok(())
3898}
3899
3900/// Delete a channel message
3901async fn remove_channel_message(
3902    request: proto::RemoveChannelMessage,
3903    response: Response<proto::RemoveChannelMessage>,
3904    session: MessageContext,
3905) -> Result<()> {
3906    let channel_id = ChannelId::from_proto(request.channel_id);
3907    let message_id = MessageId::from_proto(request.message_id);
3908    let (connection_ids, existing_notification_ids) = session
3909        .db()
3910        .await
3911        .remove_channel_message(channel_id, message_id, session.user_id())
3912        .await?;
3913
3914    broadcast(
3915        Some(session.connection_id),
3916        connection_ids,
3917        move |connection| {
3918            session.peer.send(connection, request.clone())?;
3919
3920            for notification_id in &existing_notification_ids {
3921                session.peer.send(
3922                    connection,
3923                    proto::DeleteNotification {
3924                        notification_id: (*notification_id).to_proto(),
3925                    },
3926                )?;
3927            }
3928
3929            Ok(())
3930        },
3931    );
3932    response.send(proto::Ack {})?;
3933    Ok(())
3934}
3935
3936async fn update_channel_message(
3937    request: proto::UpdateChannelMessage,
3938    response: Response<proto::UpdateChannelMessage>,
3939    session: MessageContext,
3940) -> Result<()> {
3941    let channel_id = ChannelId::from_proto(request.channel_id);
3942    let message_id = MessageId::from_proto(request.message_id);
3943    let updated_at = OffsetDateTime::now_utc();
3944    let UpdatedChannelMessage {
3945        message_id,
3946        participant_connection_ids,
3947        notifications,
3948        reply_to_message_id,
3949        timestamp,
3950        deleted_mention_notification_ids,
3951        updated_mention_notifications,
3952    } = session
3953        .db()
3954        .await
3955        .update_channel_message(
3956            channel_id,
3957            message_id,
3958            session.user_id(),
3959            request.body.as_str(),
3960            &request.mentions,
3961            updated_at,
3962        )
3963        .await?;
3964
3965    let nonce = request.nonce.clone().context("nonce can't be blank")?;
3966
3967    let message = proto::ChannelMessage {
3968        sender_id: session.user_id().to_proto(),
3969        id: message_id.to_proto(),
3970        body: request.body.clone(),
3971        mentions: request.mentions.clone(),
3972        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3973        nonce: Some(nonce),
3974        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3975        edited_at: Some(updated_at.unix_timestamp() as u64),
3976    };
3977
3978    response.send(proto::Ack {})?;
3979
3980    let pool = &*session.connection_pool().await;
3981    broadcast(
3982        Some(session.connection_id),
3983        participant_connection_ids,
3984        |connection| {
3985            session.peer.send(
3986                connection,
3987                proto::ChannelMessageUpdate {
3988                    channel_id: channel_id.to_proto(),
3989                    message: Some(message.clone()),
3990                },
3991            )?;
3992
3993            for notification_id in &deleted_mention_notification_ids {
3994                session.peer.send(
3995                    connection,
3996                    proto::DeleteNotification {
3997                        notification_id: (*notification_id).to_proto(),
3998                    },
3999                )?;
4000            }
4001
4002            for notification in &updated_mention_notifications {
4003                session.peer.send(
4004                    connection,
4005                    proto::UpdateNotification {
4006                        notification: Some(notification.clone()),
4007                    },
4008                )?;
4009            }
4010
4011            Ok(())
4012        },
4013    );
4014
4015    send_notifications(pool, &session.peer, notifications);
4016
4017    Ok(())
4018}
4019
4020/// Mark a channel message as read
4021async fn acknowledge_channel_message(
4022    request: proto::AckChannelMessage,
4023    session: MessageContext,
4024) -> Result<()> {
4025    let channel_id = ChannelId::from_proto(request.channel_id);
4026    let message_id = MessageId::from_proto(request.message_id);
4027    let notifications = session
4028        .db()
4029        .await
4030        .observe_channel_message(channel_id, session.user_id(), message_id)
4031        .await?;
4032    send_notifications(
4033        &*session.connection_pool().await,
4034        &session.peer,
4035        notifications,
4036    );
4037    Ok(())
4038}
4039
4040/// Mark a buffer version as synced
4041async fn acknowledge_buffer_version(
4042    request: proto::AckBufferOperation,
4043    session: MessageContext,
4044) -> Result<()> {
4045    let buffer_id = BufferId::from_proto(request.buffer_id);
4046    session
4047        .db()
4048        .await
4049        .observe_buffer_version(
4050            buffer_id,
4051            session.user_id(),
4052            request.epoch as i32,
4053            &request.version,
4054        )
4055        .await?;
4056    Ok(())
4057}
4058
4059/// Get a Supermaven API key for the user
4060async fn get_supermaven_api_key(
4061    _request: proto::GetSupermavenApiKey,
4062    response: Response<proto::GetSupermavenApiKey>,
4063    session: MessageContext,
4064) -> Result<()> {
4065    let user_id: String = session.user_id().to_string();
4066    if !session.is_staff() {
4067        return Err(anyhow!("supermaven not enabled for this account"))?;
4068    }
4069
4070    let email = session.email().context("user must have an email")?;
4071
4072    let supermaven_admin_api = session
4073        .supermaven_client
4074        .as_ref()
4075        .context("supermaven not configured")?;
4076
4077    let result = supermaven_admin_api
4078        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4079        .await?;
4080
4081    response.send(proto::GetSupermavenApiKeyResponse {
4082        api_key: result.api_key,
4083    })?;
4084
4085    Ok(())
4086}
4087
4088/// Start receiving chat updates for a channel
4089async fn join_channel_chat(
4090    request: proto::JoinChannelChat,
4091    response: Response<proto::JoinChannelChat>,
4092    session: MessageContext,
4093) -> Result<()> {
4094    let channel_id = ChannelId::from_proto(request.channel_id);
4095
4096    let db = session.db().await;
4097    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4098        .await?;
4099    let messages = db
4100        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4101        .await?;
4102    response.send(proto::JoinChannelChatResponse {
4103        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4104        messages,
4105    })?;
4106    Ok(())
4107}
4108
4109/// Stop receiving chat updates for a channel
4110async fn leave_channel_chat(
4111    request: proto::LeaveChannelChat,
4112    session: MessageContext,
4113) -> Result<()> {
4114    let channel_id = ChannelId::from_proto(request.channel_id);
4115    session
4116        .db()
4117        .await
4118        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4119        .await?;
4120    Ok(())
4121}
4122
4123/// Retrieve the chat history for a channel
4124async fn get_channel_messages(
4125    request: proto::GetChannelMessages,
4126    response: Response<proto::GetChannelMessages>,
4127    session: MessageContext,
4128) -> Result<()> {
4129    let channel_id = ChannelId::from_proto(request.channel_id);
4130    let messages = session
4131        .db()
4132        .await
4133        .get_channel_messages(
4134            channel_id,
4135            session.user_id(),
4136            MESSAGE_COUNT_PER_PAGE,
4137            Some(MessageId::from_proto(request.before_message_id)),
4138        )
4139        .await?;
4140    response.send(proto::GetChannelMessagesResponse {
4141        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4142        messages,
4143    })?;
4144    Ok(())
4145}
4146
4147/// Retrieve specific chat messages
4148async fn get_channel_messages_by_id(
4149    request: proto::GetChannelMessagesById,
4150    response: Response<proto::GetChannelMessagesById>,
4151    session: MessageContext,
4152) -> Result<()> {
4153    let message_ids = request
4154        .message_ids
4155        .iter()
4156        .map(|id| MessageId::from_proto(*id))
4157        .collect::<Vec<_>>();
4158    let messages = session
4159        .db()
4160        .await
4161        .get_channel_messages_by_id(session.user_id(), &message_ids)
4162        .await?;
4163    response.send(proto::GetChannelMessagesResponse {
4164        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4165        messages,
4166    })?;
4167    Ok(())
4168}
4169
4170/// Retrieve the current users notifications
4171async fn get_notifications(
4172    request: proto::GetNotifications,
4173    response: Response<proto::GetNotifications>,
4174    session: MessageContext,
4175) -> Result<()> {
4176    let notifications = session
4177        .db()
4178        .await
4179        .get_notifications(
4180            session.user_id(),
4181            NOTIFICATION_COUNT_PER_PAGE,
4182            request.before_id.map(db::NotificationId::from_proto),
4183        )
4184        .await?;
4185    response.send(proto::GetNotificationsResponse {
4186        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4187        notifications,
4188    })?;
4189    Ok(())
4190}
4191
4192/// Mark notifications as read
4193async fn mark_notification_as_read(
4194    request: proto::MarkNotificationRead,
4195    response: Response<proto::MarkNotificationRead>,
4196    session: MessageContext,
4197) -> Result<()> {
4198    let database = &session.db().await;
4199    let notifications = database
4200        .mark_notification_as_read_by_id(
4201            session.user_id(),
4202            NotificationId::from_proto(request.notification_id),
4203        )
4204        .await?;
4205    send_notifications(
4206        &*session.connection_pool().await,
4207        &session.peer,
4208        notifications,
4209    );
4210    response.send(proto::Ack {})?;
4211    Ok(())
4212}
4213
4214/// Get the current users information
4215async fn get_private_user_info(
4216    _request: proto::GetPrivateUserInfo,
4217    response: Response<proto::GetPrivateUserInfo>,
4218    session: MessageContext,
4219) -> Result<()> {
4220    let db = session.db().await;
4221
4222    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4223    let user = db
4224        .get_user_by_id(session.user_id())
4225        .await?
4226        .context("user not found")?;
4227    let flags = db.get_user_flags(session.user_id()).await?;
4228
4229    response.send(proto::GetPrivateUserInfoResponse {
4230        metrics_id,
4231        staff: user.admin,
4232        flags,
4233        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4234    })?;
4235    Ok(())
4236}
4237
4238/// Accept the terms of service (tos) on behalf of the current user
4239async fn accept_terms_of_service(
4240    _request: proto::AcceptTermsOfService,
4241    response: Response<proto::AcceptTermsOfService>,
4242    session: MessageContext,
4243) -> Result<()> {
4244    let db = session.db().await;
4245
4246    let accepted_tos_at = Utc::now();
4247    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4248        .await?;
4249
4250    response.send(proto::AcceptTermsOfServiceResponse {
4251        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4252    })?;
4253
4254    // When the user accepts the terms of service, we want to refresh their LLM
4255    // token to grant access.
4256    session
4257        .peer
4258        .send(session.connection_id, proto::RefreshLlmToken {})?;
4259
4260    Ok(())
4261}
4262
4263async fn get_llm_api_token(
4264    _request: proto::GetLlmToken,
4265    response: Response<proto::GetLlmToken>,
4266    session: MessageContext,
4267) -> Result<()> {
4268    let db = session.db().await;
4269
4270    let flags = db.get_user_flags(session.user_id()).await?;
4271
4272    let user_id = session.user_id();
4273    let user = db
4274        .get_user_by_id(user_id)
4275        .await?
4276        .with_context(|| format!("user {user_id} not found"))?;
4277
4278    if user.accepted_tos_at.is_none() {
4279        Err(anyhow!("terms of service not accepted"))?
4280    }
4281
4282    let stripe_client = session
4283        .app_state
4284        .stripe_client
4285        .as_ref()
4286        .context("failed to retrieve Stripe client")?;
4287
4288    let stripe_billing = session
4289        .app_state
4290        .stripe_billing
4291        .as_ref()
4292        .context("failed to retrieve Stripe billing object")?;
4293
4294    let billing_customer = if let Some(billing_customer) =
4295        db.get_billing_customer_by_user_id(user.id).await?
4296    {
4297        billing_customer
4298    } else {
4299        let customer_id = stripe_billing
4300            .find_or_create_customer_by_email(user.email_address.as_deref())
4301            .await?;
4302
4303        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
4304            .await?
4305            .context("billing customer not found")?
4306    };
4307
4308    let billing_subscription =
4309        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
4310            billing_subscription
4311        } else {
4312            let stripe_customer_id =
4313                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
4314
4315            let stripe_subscription = stripe_billing
4316                .subscribe_to_zed_free(stripe_customer_id)
4317                .await?;
4318
4319            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
4320                billing_customer_id: billing_customer.id,
4321                kind: Some(SubscriptionKind::ZedFree),
4322                stripe_subscription_id: stripe_subscription.id.to_string(),
4323                stripe_subscription_status: stripe_subscription.status.into(),
4324                stripe_cancellation_reason: None,
4325                stripe_current_period_start: Some(stripe_subscription.current_period_start),
4326                stripe_current_period_end: Some(stripe_subscription.current_period_end),
4327            })
4328            .await?
4329        };
4330
4331    let billing_preferences = db.get_billing_preferences(user.id).await?;
4332
4333    let token = LlmTokenClaims::create(
4334        &user,
4335        session.is_staff(),
4336        billing_customer,
4337        billing_preferences,
4338        &flags,
4339        billing_subscription,
4340        session.system_id.clone(),
4341        &session.app_state.config,
4342    )?;
4343    response.send(proto::GetLlmTokenResponse { token })?;
4344    Ok(())
4345}
4346
4347fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4348    let message = match message {
4349        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4350        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4351        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4352        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4353        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4354            code: frame.code.into(),
4355            reason: frame.reason.as_str().to_owned().into(),
4356        })),
4357        // We should never receive a frame while reading the message, according
4358        // to the `tungstenite` maintainers:
4359        //
4360        // > It cannot occur when you read messages from the WebSocket, but it
4361        // > can be used when you want to send the raw frames (e.g. you want to
4362        // > send the frames to the WebSocket without composing the full message first).
4363        // >
4364        // > — https://github.com/snapview/tungstenite-rs/issues/268
4365        TungsteniteMessage::Frame(_) => {
4366            bail!("received an unexpected frame while reading the message")
4367        }
4368    };
4369
4370    Ok(message)
4371}
4372
4373fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4374    match message {
4375        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4376        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4377        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4378        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4379        AxumMessage::Close(frame) => {
4380            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4381                code: frame.code.into(),
4382                reason: frame.reason.as_ref().into(),
4383            }))
4384        }
4385    }
4386}
4387
4388fn notify_membership_updated(
4389    connection_pool: &mut ConnectionPool,
4390    result: MembershipUpdated,
4391    user_id: UserId,
4392    peer: &Peer,
4393) {
4394    for membership in &result.new_channels.channel_memberships {
4395        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4396    }
4397    for channel_id in &result.removed_channels {
4398        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4399    }
4400
4401    let user_channels_update = proto::UpdateUserChannels {
4402        channel_memberships: result
4403            .new_channels
4404            .channel_memberships
4405            .iter()
4406            .map(|cm| proto::ChannelMembership {
4407                channel_id: cm.channel_id.to_proto(),
4408                role: cm.role.into(),
4409            })
4410            .collect(),
4411        ..Default::default()
4412    };
4413
4414    let mut update = build_channels_update(result.new_channels);
4415    update.delete_channels = result
4416        .removed_channels
4417        .into_iter()
4418        .map(|id| id.to_proto())
4419        .collect();
4420    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4421
4422    for connection_id in connection_pool.user_connection_ids(user_id) {
4423        peer.send(connection_id, user_channels_update.clone())
4424            .trace_err();
4425        peer.send(connection_id, update.clone()).trace_err();
4426    }
4427}
4428
4429fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4430    proto::UpdateUserChannels {
4431        channel_memberships: channels
4432            .channel_memberships
4433            .iter()
4434            .map(|m| proto::ChannelMembership {
4435                channel_id: m.channel_id.to_proto(),
4436                role: m.role.into(),
4437            })
4438            .collect(),
4439        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4440        observed_channel_message_id: channels.observed_channel_messages.clone(),
4441    }
4442}
4443
4444fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4445    let mut update = proto::UpdateChannels::default();
4446
4447    for channel in channels.channels {
4448        update.channels.push(channel.to_proto());
4449    }
4450
4451    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4452    update.latest_channel_message_ids = channels.latest_channel_messages;
4453
4454    for (channel_id, participants) in channels.channel_participants {
4455        update
4456            .channel_participants
4457            .push(proto::ChannelParticipants {
4458                channel_id: channel_id.to_proto(),
4459                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4460            });
4461    }
4462
4463    for channel in channels.invited_channels {
4464        update.channel_invitations.push(channel.to_proto());
4465    }
4466
4467    update
4468}
4469
4470fn build_initial_contacts_update(
4471    contacts: Vec<db::Contact>,
4472    pool: &ConnectionPool,
4473) -> proto::UpdateContacts {
4474    let mut update = proto::UpdateContacts::default();
4475
4476    for contact in contacts {
4477        match contact {
4478            db::Contact::Accepted { user_id, busy } => {
4479                update.contacts.push(contact_for_user(user_id, busy, pool));
4480            }
4481            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4482            db::Contact::Incoming { user_id } => {
4483                update
4484                    .incoming_requests
4485                    .push(proto::IncomingContactRequest {
4486                        requester_id: user_id.to_proto(),
4487                    })
4488            }
4489        }
4490    }
4491
4492    update
4493}
4494
4495fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4496    proto::Contact {
4497        user_id: user_id.to_proto(),
4498        online: pool.is_user_online(user_id),
4499        busy,
4500    }
4501}
4502
4503fn room_updated(room: &proto::Room, peer: &Peer) {
4504    broadcast(
4505        None,
4506        room.participants
4507            .iter()
4508            .filter_map(|participant| Some(participant.peer_id?.into())),
4509        |peer_id| {
4510            peer.send(
4511                peer_id,
4512                proto::RoomUpdated {
4513                    room: Some(room.clone()),
4514                },
4515            )
4516        },
4517    );
4518}
4519
4520fn channel_updated(
4521    channel: &db::channel::Model,
4522    room: &proto::Room,
4523    peer: &Peer,
4524    pool: &ConnectionPool,
4525) {
4526    let participants = room
4527        .participants
4528        .iter()
4529        .map(|p| p.user_id)
4530        .collect::<Vec<_>>();
4531
4532    broadcast(
4533        None,
4534        pool.channel_connection_ids(channel.root_id())
4535            .filter_map(|(channel_id, role)| {
4536                role.can_see_channel(channel.visibility)
4537                    .then_some(channel_id)
4538            }),
4539        |peer_id| {
4540            peer.send(
4541                peer_id,
4542                proto::UpdateChannels {
4543                    channel_participants: vec![proto::ChannelParticipants {
4544                        channel_id: channel.id.to_proto(),
4545                        participant_user_ids: participants.clone(),
4546                    }],
4547                    ..Default::default()
4548                },
4549            )
4550        },
4551    );
4552}
4553
4554async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4555    let db = session.db().await;
4556
4557    let contacts = db.get_contacts(user_id).await?;
4558    let busy = db.is_user_busy(user_id).await?;
4559
4560    let pool = session.connection_pool().await;
4561    let updated_contact = contact_for_user(user_id, busy, &pool);
4562    for contact in contacts {
4563        if let db::Contact::Accepted {
4564            user_id: contact_user_id,
4565            ..
4566        } = contact
4567        {
4568            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4569                session
4570                    .peer
4571                    .send(
4572                        contact_conn_id,
4573                        proto::UpdateContacts {
4574                            contacts: vec![updated_contact.clone()],
4575                            remove_contacts: Default::default(),
4576                            incoming_requests: Default::default(),
4577                            remove_incoming_requests: Default::default(),
4578                            outgoing_requests: Default::default(),
4579                            remove_outgoing_requests: Default::default(),
4580                        },
4581                    )
4582                    .trace_err();
4583            }
4584        }
4585    }
4586    Ok(())
4587}
4588
/// Remove `connection_id` from the room it is in and notify everyone affected.
///
/// If `leave_room` reports that no room was left, this returns `Ok(())` and
/// does nothing else. Otherwise it:
/// - notifies the other connections of every project the user left
///   (`project_left`);
/// - broadcasts the updated room to its participants and, if the room belongs
///   to a channel, the updated participant list to that channel's connections;
/// - sends `CallCanceled` to every connection of users whose pending calls
///   were canceled;
/// - refreshes the contact entry of this user and of those call recipients;
/// - removes the user from the LiveKit room, and deletes the LiveKit room when
///   the database reported it as deleted. LiveKit failures are only logged.
///
/// NOTE(review): `project_left` names the departing collaborator by
/// `session.connection_id`, not by `connection_id`. Presumably callers always
/// pass the session's own connection — verify.
///
/// # Errors
/// Propagates database errors from `leave_room` and `update_user_contacts`.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred-initialized locals. Values are moved out of the left room so
    // they outlive the `if let` below. The temporary from `session.db().await`
    // in that `if let` is dropped when the statement ends, before
    // `update_user_contacts` calls `session.db()` again.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out. As a result, the `room` broadcast
        // below carries an empty (default) `livekit_room`.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the pool guard is released before `update_user_contacts`,
    // which acquires the connection pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // Best effort: LiveKit errors are logged via `trace_err` and do not fail
    // the call.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4662
4663async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4664    let left_channel_buffers = session
4665        .db()
4666        .await
4667        .leave_channel_buffers(session.connection_id)
4668        .await?;
4669
4670    for left_buffer in left_channel_buffers {
4671        channel_buffer_updated(
4672            session.connection_id,
4673            left_buffer.connections,
4674            &proto::UpdateChannelBufferCollaborators {
4675                channel_id: left_buffer.channel_id.to_proto(),
4676                collaborators: left_buffer.collaborators,
4677            },
4678            &session.peer,
4679        );
4680    }
4681
4682    Ok(())
4683}
4684
4685fn project_left(project: &db::LeftProject, session: &Session) {
4686    for connection_id in &project.connection_ids {
4687        if project.should_unshare {
4688            session
4689                .peer
4690                .send(
4691                    *connection_id,
4692                    proto::UnshareProject {
4693                        project_id: project.id.to_proto(),
4694                    },
4695                )
4696                .trace_err();
4697        } else {
4698            session
4699                .peer
4700                .send(
4701                    *connection_id,
4702                    proto::RemoveProjectCollaborator {
4703                        project_id: project.id.to_proto(),
4704                        peer_id: Some(session.connection_id.into()),
4705                    },
4706                )
4707                .trace_err();
4708        }
4709    }
4710}
4711
4712pub trait ResultExt {
4713    type Ok;
4714
4715    fn trace_err(self) -> Option<Self::Ok>;
4716}
4717
4718impl<T, E> ResultExt for Result<T, E>
4719where
4720    E: std::fmt::Debug,
4721{
4722    type Ok = T;
4723
4724    #[track_caller]
4725    fn trace_err(self) -> Option<T> {
4726        match self {
4727            Ok(value) => Some(value),
4728            Err(error) => {
4729                tracing::error!("{:?}", error);
4730                None
4731            }
4732        }
4733    }
4734}