rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::db::billing_subscription::SubscriptionKind;
   5use crate::llm::db::LlmDatabase;
   6use crate::llm::{
   7    AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG,
   8    MIN_ACCOUNT_AGE_FOR_LLM_USE,
   9};
  10use crate::{
  11    AppState, Error, Result, auth,
  12    db::{
  13        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
  14        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  15        NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
  16        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  17    },
  18    executor::Executor,
  19};
  20use anyhow::{Context as _, anyhow, bail};
  21use async_tungstenite::tungstenite::{
  22    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  23};
  24use axum::headers::UserAgent;
  25use axum::{
  26    Extension, Router, TypedHeader,
  27    body::Body,
  28    extract::{
  29        ConnectInfo, WebSocketUpgrade,
  30        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  31    },
  32    headers::{Header, HeaderName},
  33    http::StatusCode,
  34    middleware,
  35    response::IntoResponse,
  36    routing::get,
  37};
  38use chrono::Utc;
  39use collections::{HashMap, HashSet};
  40pub use connection_pool::{ConnectionPool, ZedVersion};
  41use core::fmt::{self, Debug, Formatter};
  42use futures::TryFutureExt as _;
  43use reqwest_client::ReqwestClient;
  44use rpc::proto::{MultiLspQuery, split_repository_update};
  45use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  46use tracing::Span;
  47
  48use futures::{
  49    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  50    stream::FuturesUnordered,
  51};
  52use prometheus::{IntGauge, register_int_gauge};
  53use rpc::{
  54    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  55    proto::{
  56        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  57        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  58    },
  59};
  60use semantic_version::SemanticVersion;
  61use serde::{Serialize, Serializer};
  62use std::{
  63    any::TypeId,
  64    future::Future,
  65    marker::PhantomData,
  66    mem,
  67    net::SocketAddr,
  68    ops::{Deref, DerefMut},
  69    rc::Rc,
  70    sync::{
  71        Arc, OnceLock,
  72        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  73    },
  74    time::{Duration, Instant},
  75};
  76use time::OffsetDateTime;
  77use tokio::sync::{Semaphore, watch};
  78use tower::ServiceBuilder;
  79use tracing::{
  80    Instrument,
  81    field::{self},
  82    info_span, instrument,
  83};
  84
  85pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  86
  87// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  88pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  89
  90const MESSAGE_COUNT_PER_PAGE: usize = 100;
  91const MAX_MESSAGE_LEN: usize = 1024;
  92const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  93const MAX_CONCURRENT_CONNECTIONS: usize = 512;
  94
  95static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
  96
  97const TOTAL_DURATION_MS: &str = "total_duration_ms";
  98const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
  99const QUEUE_DURATION_MS: &str = "queue_duration_ms";
 100const HOST_WAITING_MS: &str = "host_waiting_ms";
 101
 102type MessageHandler =
 103    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
 104
 105pub struct ConnectionGuard;
 106
 107impl ConnectionGuard {
 108    pub fn try_acquire() -> Result<Self, ()> {
 109        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
 110        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 111            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 112            tracing::error!(
 113                "too many concurrent connections: {}",
 114                current_connections + 1
 115            );
 116            return Err(());
 117        }
 118        Ok(ConnectionGuard)
 119    }
 120}
 121
 122impl Drop for ConnectionGuard {
 123    fn drop(&mut self) {
 124        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 125    }
 126}
 127
 128struct Response<R> {
 129    peer: Arc<Peer>,
 130    receipt: Receipt<R>,
 131    responded: Arc<AtomicBool>,
 132}
 133
 134impl<R: RequestMessage> Response<R> {
 135    fn send(self, payload: R::Response) -> Result<()> {
 136        self.responded.store(true, SeqCst);
 137        self.peer.respond(self.receipt, payload)?;
 138        Ok(())
 139    }
 140}
 141
 142#[derive(Clone, Debug)]
 143pub enum Principal {
 144    User(User),
 145    Impersonated { user: User, admin: User },
 146}
 147
 148impl Principal {
 149    fn user(&self) -> &User {
 150        match self {
 151            Principal::User(user) => user,
 152            Principal::Impersonated { user, .. } => user,
 153        }
 154    }
 155
 156    fn update_span(&self, span: &tracing::Span) {
 157        match &self {
 158            Principal::User(user) => {
 159                span.record("user_id", user.id.0);
 160                span.record("login", &user.github_login);
 161            }
 162            Principal::Impersonated { user, admin } => {
 163                span.record("user_id", user.id.0);
 164                span.record("login", &user.github_login);
 165                span.record("impersonator", &admin.github_login);
 166            }
 167        }
 168    }
 169}
 170
 171#[derive(Clone)]
 172struct MessageContext {
 173    session: Session,
 174    span: tracing::Span,
 175}
 176
 177impl Deref for MessageContext {
 178    type Target = Session;
 179
 180    fn deref(&self) -> &Self::Target {
 181        &self.session
 182    }
 183}
 184
 185impl MessageContext {
 186    pub fn forward_request<T: RequestMessage>(
 187        &self,
 188        receiver_id: ConnectionId,
 189        request: T,
 190    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 191        let request_start_time = Instant::now();
 192        let span = self.span.clone();
 193        tracing::info!("start forwarding request");
 194        self.peer
 195            .forward_request(self.connection_id, receiver_id, request)
 196            .inspect(move |_| {
 197                span.record(
 198                    HOST_WAITING_MS,
 199                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 200                );
 201            })
 202            .inspect_err(|_| tracing::error!("error forwarding request"))
 203            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 204    }
 205}
 206
 207#[derive(Clone)]
 208struct Session {
 209    principal: Principal,
 210    connection_id: ConnectionId,
 211    db: Arc<tokio::sync::Mutex<DbHandle>>,
 212    peer: Arc<Peer>,
 213    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 214    app_state: Arc<AppState>,
 215    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 216    /// The GeoIP country code for the user.
 217    #[allow(unused)]
 218    geoip_country_code: Option<String>,
 219    #[allow(unused)]
 220    system_id: Option<String>,
 221    _executor: Executor,
 222}
 223
 224impl Session {
 225    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 226        #[cfg(test)]
 227        tokio::task::yield_now().await;
 228        let guard = self.db.lock().await;
 229        #[cfg(test)]
 230        tokio::task::yield_now().await;
 231        guard
 232    }
 233
 234    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 235        #[cfg(test)]
 236        tokio::task::yield_now().await;
 237        let guard = self.connection_pool.lock();
 238        ConnectionPoolGuard {
 239            guard,
 240            _not_send: PhantomData,
 241        }
 242    }
 243
 244    fn is_staff(&self) -> bool {
 245        match &self.principal {
 246            Principal::User(user) => user.admin,
 247            Principal::Impersonated { .. } => true,
 248        }
 249    }
 250
 251    fn user_id(&self) -> UserId {
 252        match &self.principal {
 253            Principal::User(user) => user.id,
 254            Principal::Impersonated { user, .. } => user.id,
 255        }
 256    }
 257
 258    pub fn email(&self) -> Option<String> {
 259        match &self.principal {
 260            Principal::User(user) => user.email_address.clone(),
 261            Principal::Impersonated { user, .. } => user.email_address.clone(),
 262        }
 263    }
 264}
 265
 266impl Debug for Session {
 267    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 268        let mut result = f.debug_struct("Session");
 269        match &self.principal {
 270            Principal::User(user) => {
 271                result.field("user", &user.github_login);
 272            }
 273            Principal::Impersonated { user, admin } => {
 274                result.field("user", &user.github_login);
 275                result.field("impersonator", &admin.github_login);
 276            }
 277        }
 278        result.field("connection_id", &self.connection_id).finish()
 279    }
 280}
 281
 282struct DbHandle(Arc<Database>);
 283
 284impl Deref for DbHandle {
 285    type Target = Database;
 286
 287    fn deref(&self) -> &Self::Target {
 288        self.0.as_ref()
 289    }
 290}
 291
/// The collaboration RPC server. It owns the RPC [`Peer`], the shared
/// connection pool, and the message-handler table keyed by message type.
pub struct Server {
    /// Id of this server instance. Behind a mutex; the only writer visible here
    /// is the test-only `reset`.
    id: parking_lot::Mutex<ServerId>,
    /// RPC peer used to send messages and responses to connections.
    peer: Arc<Peer>,
    /// Registry of live connections (queried e.g. by user id during cleanup).
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state (database, config, executor, LiveKit client).
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown` and back to `false` by the test-only `reset`.
    teardown: watch::Sender<bool>,
}
 300
/// Lock guard over the shared [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so it cannot be
/// held across an `.await` inside a future that has to be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable snapshot of the server: the peer plus the locked connection
/// pool. The pool lock is held for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialize what the guard dereferences to (its `Deref` impl lives
    // elsewhere), not the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 312
 313pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 314where
 315    S: Serializer,
 316    T: Deref<Target = U>,
 317    U: Serialize,
 318{
 319    Serialize::serialize(value.deref(), serializer)
 320}
 321
 322impl Server {
 323    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 324        let mut server = Self {
 325            id: parking_lot::Mutex::new(id),
 326            peer: Peer::new(id.0 as u32),
 327            app_state: app_state.clone(),
 328            connection_pool: Default::default(),
 329            handlers: Default::default(),
 330            teardown: watch::channel(false).0,
 331        };
 332
 333        server
 334            .add_request_handler(ping)
 335            .add_request_handler(create_room)
 336            .add_request_handler(join_room)
 337            .add_request_handler(rejoin_room)
 338            .add_request_handler(leave_room)
 339            .add_request_handler(set_room_participant_role)
 340            .add_request_handler(call)
 341            .add_request_handler(cancel_call)
 342            .add_message_handler(decline_call)
 343            .add_request_handler(update_participant_location)
 344            .add_request_handler(share_project)
 345            .add_message_handler(unshare_project)
 346            .add_request_handler(join_project)
 347            .add_message_handler(leave_project)
 348            .add_request_handler(update_project)
 349            .add_request_handler(update_worktree)
 350            .add_request_handler(update_repository)
 351            .add_request_handler(remove_repository)
 352            .add_message_handler(start_language_server)
 353            .add_message_handler(update_language_server)
 354            .add_message_handler(update_diagnostic_summary)
 355            .add_message_handler(update_worktree_settings)
 356            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 357            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 358            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 359            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 360            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 361            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 362            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 363            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 364            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 365            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 366            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 367            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 368            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 369            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 370            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 371            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 372            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 373            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 374            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 375            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 376            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 377            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 378            .add_request_handler(
 379                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 380            )
 381            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 382            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 383            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 384            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 385            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
 386            .add_request_handler(
 387                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 388            )
 389            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 390            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 391            .add_request_handler(
 392                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 393            )
 394            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 395            .add_request_handler(
 396                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 397            )
 398            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 399            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 400            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 401            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 402            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 403            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 404            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 405            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 406            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 407            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 408            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 409            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 410            .add_request_handler(
 411                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 412            )
 413            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 414            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 415            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 416            .add_request_handler(multi_lsp_query)
 417            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 418            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 419            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 420            .add_message_handler(create_buffer_for_peer)
 421            .add_request_handler(update_buffer)
 422            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 423            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 424            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 425            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 426            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 427            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 428            .add_message_handler(
 429                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 430            )
 431            .add_request_handler(get_users)
 432            .add_request_handler(fuzzy_search_users)
 433            .add_request_handler(request_contact)
 434            .add_request_handler(remove_contact)
 435            .add_request_handler(respond_to_contact_request)
 436            .add_message_handler(subscribe_to_channels)
 437            .add_request_handler(create_channel)
 438            .add_request_handler(delete_channel)
 439            .add_request_handler(invite_channel_member)
 440            .add_request_handler(remove_channel_member)
 441            .add_request_handler(set_channel_member_role)
 442            .add_request_handler(set_channel_visibility)
 443            .add_request_handler(rename_channel)
 444            .add_request_handler(join_channel_buffer)
 445            .add_request_handler(leave_channel_buffer)
 446            .add_message_handler(update_channel_buffer)
 447            .add_request_handler(rejoin_channel_buffers)
 448            .add_request_handler(get_channel_members)
 449            .add_request_handler(respond_to_channel_invite)
 450            .add_request_handler(join_channel)
 451            .add_request_handler(join_channel_chat)
 452            .add_message_handler(leave_channel_chat)
 453            .add_request_handler(send_channel_message)
 454            .add_request_handler(remove_channel_message)
 455            .add_request_handler(update_channel_message)
 456            .add_request_handler(get_channel_messages)
 457            .add_request_handler(get_channel_messages_by_id)
 458            .add_request_handler(get_notifications)
 459            .add_request_handler(mark_notification_as_read)
 460            .add_request_handler(move_channel)
 461            .add_request_handler(reorder_channel)
 462            .add_request_handler(follow)
 463            .add_message_handler(unfollow)
 464            .add_message_handler(update_followers)
 465            .add_request_handler(accept_terms_of_service)
 466            .add_message_handler(acknowledge_channel_message)
 467            .add_message_handler(acknowledge_buffer_version)
 468            .add_request_handler(get_supermaven_api_key)
 469            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 470            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 471            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 472            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 473            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 474            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 475            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 476            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 477            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 478            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 479            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 480            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 481            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 482            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 483            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 484            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 485            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 486            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 487            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 488            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 489            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 490            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 491            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 492            .add_message_handler(update_context);
 493
 494        Arc::new(server)
 495    }
 496
    /// Spawns a detached background task that cleans up resources left behind
    /// by other (stale) server instances, then returns `Ok(())` immediately.
    ///
    /// After waiting `CLEANUP_TIMEOUT`, the task:
    /// - deletes stale channel-chat participants;
    /// - for each stale channel buffer, clears stale collaborators and sends
    ///   the refreshed collaborator list to the buffer's remaining connections;
    /// - for each stale room:
    ///   - clears stale participants and broadcasts the room (and channel) update;
    ///   - sends `CallCanceled` to users whose calls were canceled;
    ///   - pushes updated contact entries to the affected users' accepted contacts;
    ///   - deletes the LiveKit room if no participants remain;
    /// - finally, deletes stale chat participants again, clears old worktree
    ///   entries, and deletes stale server records.
    ///
    /// Every fallible step goes through `trace_err`; a failing step does not
    /// abort the remaining steps.
    pub async fn start(&self) -> Result<()> {
        // Snapshot everything the detached task needs; it must not borrow `self`.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and tell the
                    // buffer's remaining connections about the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These defaults apply if clearing the room failed: nobody is
                        // notified and the LiveKit room is left alone.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and users whose calls were
                            // canceled need their contact entries refreshed.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell every connection of each affected user that the call
                        // into this room was canceled. The pool lock is scoped to
                        // this block.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push the refreshed contact entry (including busy state) for
                        // each affected user to all connections of their accepted
                        // contacts. Skipped for a user if either lookup failed.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room once nobody is left in it.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // NOTE(review): identical to the call at the top of this task;
                // confirm whether repeating it is intentional.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 667
 668    pub fn teardown(&self) {
 669        self.peer.teardown();
 670        self.connection_pool.lock().reset();
 671        let _ = self.teardown.send(true);
 672    }
 673
 674    #[cfg(test)]
 675    pub fn reset(&self, id: ServerId) {
 676        self.teardown();
 677        *self.id.lock() = id;
 678        self.peer.reset(id.0 as u32);
 679        let _ = self.teardown.send(false);
 680    }
 681
 682    #[cfg(test)]
 683    pub fn id(&self) -> ServerId {
 684        *self.id.lock()
 685    }
 686
 687    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 688    where
 689        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
 690        Fut: 'static + Send + Future<Output = Result<()>>,
 691        M: EnvelopedMessage,
 692    {
 693        let prev_handler = self.handlers.insert(
 694            TypeId::of::<M>(),
 695            Box::new(move |envelope, session, span| {
 696                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 697                let received_at = envelope.received_at;
 698                tracing::info!("message received");
 699                let start_time = Instant::now();
 700                let future = (handler)(
 701                    *envelope,
 702                    MessageContext {
 703                        session,
 704                        span: span.clone(),
 705                    },
 706                );
 707                async move {
 708                    let result = future.await;
 709                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 710                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 711                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 712                    span.record(TOTAL_DURATION_MS, total_duration_ms);
 713                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
 714                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
 715                    match result {
 716                        Err(error) => {
 717                            tracing::error!(?error, "error handling message")
 718                        }
 719                        Ok(()) => tracing::info!("finished handling message"),
 720                    }
 721                }
 722                .boxed()
 723            }),
 724        );
 725        if prev_handler.is_some() {
 726            panic!("registered a handler for the same message twice");
 727        }
 728        self
 729    }
 730
 731    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 732    where
 733        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 734        Fut: 'static + Send + Future<Output = Result<()>>,
 735        M: EnvelopedMessage,
 736    {
 737        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 738        self
 739    }
 740
 741    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 742    where
 743        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 744        Fut: Send + Future<Output = Result<()>>,
 745        M: RequestMessage,
 746    {
 747        let handler = Arc::new(handler);
 748        self.add_handler(move |envelope, session| {
 749            let receipt = envelope.receipt();
 750            let handler = handler.clone();
 751            async move {
 752                let peer = session.peer.clone();
 753                let responded = Arc::new(AtomicBool::default());
 754                let response = Response {
 755                    peer: peer.clone(),
 756                    responded: responded.clone(),
 757                    receipt,
 758                };
 759                match (handler)(envelope.payload, response, session).await {
 760                    Ok(()) => {
 761                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 762                            Ok(())
 763                        } else {
 764                            Err(anyhow!("handler did not send a response"))?
 765                        }
 766                    }
 767                    Err(error) => {
 768                        let proto_err = match &error {
 769                            Error::Internal(err) => err.to_proto(),
 770                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 771                        };
 772                        peer.respond_with_error(receipt, proto_err)?;
 773                        Err(error)
 774                    }
 775                }
 776            }
 777        })
 778    }
 779
 780    pub fn handle_connection(
 781        self: &Arc<Self>,
 782        connection: Connection,
 783        address: String,
 784        principal: Principal,
 785        zed_version: ZedVersion,
 786        release_channel: Option<String>,
 787        user_agent: Option<String>,
 788        geoip_country_code: Option<String>,
 789        system_id: Option<String>,
 790        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 791        executor: Executor,
 792        connection_guard: Option<ConnectionGuard>,
 793    ) -> impl Future<Output = ()> + use<> {
 794        let this = self.clone();
 795        let span = info_span!("handle connection", %address,
 796            connection_id=field::Empty,
 797            user_id=field::Empty,
 798            login=field::Empty,
 799            impersonator=field::Empty,
 800            user_agent=field::Empty,
 801            geoip_country_code=field::Empty,
 802            release_channel=field::Empty,
 803        );
 804        principal.update_span(&span);
 805        if let Some(user_agent) = user_agent {
 806            span.record("user_agent", user_agent);
 807        }
 808        if let Some(release_channel) = release_channel {
 809            span.record("release_channel", release_channel);
 810        }
 811
 812        if let Some(country_code) = geoip_country_code.as_ref() {
 813            span.record("geoip_country_code", country_code);
 814        }
 815
 816        let mut teardown = self.teardown.subscribe();
 817        async move {
 818            if *teardown.borrow() {
 819                tracing::error!("server is tearing down");
 820                return;
 821            }
 822
 823            let (connection_id, handle_io, mut incoming_rx) =
 824                this.peer.add_connection(connection, {
 825                    let executor = executor.clone();
 826                    move |duration| executor.sleep(duration)
 827                });
 828            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 829
 830            tracing::info!("connection opened");
 831
 832            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
 833            let http_client = match ReqwestClient::user_agent(&user_agent) {
 834                Ok(http_client) => Arc::new(http_client),
 835                Err(error) => {
 836                    tracing::error!(?error, "failed to create HTTP client");
 837                    return;
 838                }
 839            };
 840
 841            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
 842                |supermaven_admin_api_key| {
 843                    Arc::new(SupermavenAdminApi::new(
 844                        supermaven_admin_api_key.to_string(),
 845                        http_client.clone(),
 846                    ))
 847                },
 848            );
 849
 850            let session = Session {
 851                principal: principal.clone(),
 852                connection_id,
 853                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 854                peer: this.peer.clone(),
 855                connection_pool: this.connection_pool.clone(),
 856                app_state: this.app_state.clone(),
 857                geoip_country_code,
 858                system_id,
 859                _executor: executor.clone(),
 860                supermaven_client,
 861            };
 862
 863            if let Err(error) = this
 864                .send_initial_client_update(
 865                    connection_id,
 866                    zed_version,
 867                    send_connection_id,
 868                    &session,
 869                )
 870                .await
 871            {
 872                tracing::error!(?error, "failed to send initial client update");
 873                return;
 874            }
 875            drop(connection_guard);
 876
 877            let handle_io = handle_io.fuse();
 878            futures::pin_mut!(handle_io);
 879
 880            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 881            // This prevents deadlocks when e.g., client A performs a request to client B and
 882            // client B performs a request to client A. If both clients stop processing further
 883            // messages until their respective request completes, they won't have a chance to
 884            // respond to the other client's request and cause a deadlock.
 885            //
 886            // This arrangement ensures we will attempt to process earlier messages first, but fall
 887            // back to processing messages arrived later in the spirit of making progress.
 888            const MAX_CONCURRENT_HANDLERS: usize = 256;
 889            let mut foreground_message_handlers = FuturesUnordered::new();
 890            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
 891            let get_concurrent_handlers = {
 892                let concurrent_handlers = concurrent_handlers.clone();
 893                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
 894            };
 895            loop {
 896                let next_message = async {
 897                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 898                    let message = incoming_rx.next().await;
 899                    // Cache the concurrent_handlers here, so that we know what the
 900                    // queue looks like as each handler starts
 901                    (permit, message, get_concurrent_handlers())
 902                }
 903                .fuse();
 904                futures::pin_mut!(next_message);
 905                futures::select_biased! {
 906                    _ = teardown.changed().fuse() => return,
 907                    result = handle_io => {
 908                        if let Err(error) = result {
 909                            tracing::error!(?error, "error handling I/O");
 910                        }
 911                        break;
 912                    }
 913                    _ = foreground_message_handlers.next() => {}
 914                    next_message = next_message => {
 915                        let (permit, message, concurrent_handlers) = next_message;
 916                        if let Some(message) = message {
 917                            let type_name = message.payload_type_name();
 918                            // note: we copy all the fields from the parent span so we can query them in the logs.
 919                            // (https://github.com/tokio-rs/tracing/issues/2670).
 920                            let span = tracing::info_span!("receive message",
 921                                %connection_id,
 922                                %address,
 923                                type_name,
 924                                concurrent_handlers,
 925                                user_id=field::Empty,
 926                                login=field::Empty,
 927                                impersonator=field::Empty,
 928                                multi_lsp_query_request=field::Empty,
 929                                release_channel=field::Empty,
 930                                { TOTAL_DURATION_MS }=field::Empty,
 931                                { PROCESSING_DURATION_MS }=field::Empty,
 932                                { QUEUE_DURATION_MS }=field::Empty,
 933                                { HOST_WAITING_MS }=field::Empty
 934                            );
 935                            principal.update_span(&span);
 936                            let span_enter = span.enter();
 937                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 938                                let is_background = message.is_background();
 939                                let handle_message = (handler)(message, session.clone(), span.clone());
 940                                drop(span_enter);
 941
 942                                let handle_message = async move {
 943                                    handle_message.await;
 944                                    drop(permit);
 945                                }.instrument(span);
 946                                if is_background {
 947                                    executor.spawn_detached(handle_message);
 948                                } else {
 949                                    foreground_message_handlers.push(handle_message);
 950                                }
 951                            } else {
 952                                tracing::error!("no message handler");
 953                            }
 954                        } else {
 955                            tracing::info!("connection closed");
 956                            break;
 957                        }
 958                    }
 959                }
 960            }
 961
 962            drop(foreground_message_handlers);
 963            let concurrent_handlers = get_concurrent_handlers();
 964            tracing::info!(concurrent_handlers, "signing out");
 965            if let Err(error) = connection_lost(session, teardown, executor).await {
 966                tracing::error!(?error, "error signing out");
 967            }
 968        }
 969        .instrument(span)
 970    }
 971
 972    async fn send_initial_client_update(
 973        &self,
 974        connection_id: ConnectionId,
 975        zed_version: ZedVersion,
 976        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 977        session: &Session,
 978    ) -> Result<()> {
 979        self.peer.send(
 980            connection_id,
 981            proto::Hello {
 982                peer_id: Some(connection_id.into()),
 983            },
 984        )?;
 985        tracing::info!("sent hello message");
 986        if let Some(send_connection_id) = send_connection_id.take() {
 987            let _ = send_connection_id.send(connection_id);
 988        }
 989
 990        match &session.principal {
 991            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 992                if !user.connected_once {
 993                    self.peer.send(connection_id, proto::ShowContacts {})?;
 994                    self.app_state
 995                        .db
 996                        .set_user_connected_once(user.id, true)
 997                        .await?;
 998                }
 999
1000                update_user_plan(session).await?;
1001
1002                let contacts = self.app_state.db.get_contacts(user.id).await?;
1003
1004                {
1005                    let mut pool = self.connection_pool.lock();
1006                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1007                    self.peer.send(
1008                        connection_id,
1009                        build_initial_contacts_update(contacts, &pool),
1010                    )?;
1011                }
1012
1013                if should_auto_subscribe_to_channels(zed_version) {
1014                    subscribe_user_to_channels(user.id, session).await?;
1015                }
1016
1017                if let Some(incoming_call) =
1018                    self.app_state.db.incoming_call_for_user(user.id).await?
1019                {
1020                    self.peer.send(connection_id, incoming_call)?;
1021                }
1022
1023                update_user_contacts(user.id, session).await?;
1024            }
1025        }
1026
1027        Ok(())
1028    }
1029
1030    pub async fn invite_code_redeemed(
1031        self: &Arc<Self>,
1032        inviter_id: UserId,
1033        invitee_id: UserId,
1034    ) -> Result<()> {
1035        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1036            if let Some(code) = &user.invite_code {
1037                let pool = self.connection_pool.lock();
1038                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1039                for connection_id in pool.user_connection_ids(inviter_id) {
1040                    self.peer.send(
1041                        connection_id,
1042                        proto::UpdateContacts {
1043                            contacts: vec![invitee_contact.clone()],
1044                            ..Default::default()
1045                        },
1046                    )?;
1047                    self.peer.send(
1048                        connection_id,
1049                        proto::UpdateInviteInfo {
1050                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1051                            count: user.invite_count as u32,
1052                        },
1053                    )?;
1054                }
1055            }
1056        }
1057        Ok(())
1058    }
1059
1060    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1061        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1062            if let Some(invite_code) = &user.invite_code {
1063                let pool = self.connection_pool.lock();
1064                for connection_id in pool.user_connection_ids(user_id) {
1065                    self.peer.send(
1066                        connection_id,
1067                        proto::UpdateInviteInfo {
1068                            url: format!(
1069                                "{}{}",
1070                                self.app_state.config.invite_link_prefix, invite_code
1071                            ),
1072                            count: user.invite_count as u32,
1073                        },
1074                    )?;
1075                }
1076            }
1077        }
1078        Ok(())
1079    }
1080
1081    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1082        ServerSnapshot {
1083            connection_pool: ConnectionPoolGuard {
1084                guard: self.connection_pool.lock(),
1085                _not_send: PhantomData,
1086            },
1087            peer: &self.peer,
1088        }
1089    }
1090}
1091
1092impl Deref for ConnectionPoolGuard<'_> {
1093    type Target = ConnectionPool;
1094
1095    fn deref(&self) -> &Self::Target {
1096        &self.guard
1097    }
1098}
1099
1100impl DerefMut for ConnectionPoolGuard<'_> {
1101    fn deref_mut(&mut self) -> &mut Self::Target {
1102        &mut self.guard
1103    }
1104}
1105
1106impl Drop for ConnectionPoolGuard<'_> {
1107    fn drop(&mut self) {
1108        #[cfg(test)]
1109        self.check_invariants();
1110    }
1111}
1112
1113fn broadcast<F>(
1114    sender_id: Option<ConnectionId>,
1115    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1116    mut f: F,
1117) where
1118    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1119{
1120    for receiver_id in receiver_ids {
1121        if Some(receiver_id) != sender_id {
1122            if let Err(error) = f(receiver_id) {
1123                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1124            }
1125        }
1126    }
1127}
1128
1129pub struct ProtocolVersion(u32);
1130
1131impl Header for ProtocolVersion {
1132    fn name() -> &'static HeaderName {
1133        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1134        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1135    }
1136
1137    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1138    where
1139        Self: Sized,
1140        I: Iterator<Item = &'i axum::http::HeaderValue>,
1141    {
1142        let version = values
1143            .next()
1144            .ok_or_else(axum::headers::Error::invalid)?
1145            .to_str()
1146            .map_err(|_| axum::headers::Error::invalid())?
1147            .parse()
1148            .map_err(|_| axum::headers::Error::invalid())?;
1149        Ok(Self(version))
1150    }
1151
1152    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1153        values.extend([self.0.to_string().parse().unwrap()]);
1154    }
1155}
1156
1157pub struct AppVersionHeader(SemanticVersion);
1158impl Header for AppVersionHeader {
1159    fn name() -> &'static HeaderName {
1160        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1161        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1162    }
1163
1164    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1165    where
1166        Self: Sized,
1167        I: Iterator<Item = &'i axum::http::HeaderValue>,
1168    {
1169        let version = values
1170            .next()
1171            .ok_or_else(axum::headers::Error::invalid)?
1172            .to_str()
1173            .map_err(|_| axum::headers::Error::invalid())?
1174            .parse()
1175            .map_err(|_| axum::headers::Error::invalid())?;
1176        Ok(Self(version))
1177    }
1178
1179    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1180        values.extend([self.0.to_string().parse().unwrap()]);
1181    }
1182}
1183
1184#[derive(Debug)]
1185pub struct ReleaseChannelHeader(String);
1186
1187impl Header for ReleaseChannelHeader {
1188    fn name() -> &'static HeaderName {
1189        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1190        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1191    }
1192
1193    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1194    where
1195        Self: Sized,
1196        I: Iterator<Item = &'i axum::http::HeaderValue>,
1197    {
1198        Ok(Self(
1199            values
1200                .next()
1201                .ok_or_else(axum::headers::Error::invalid)?
1202                .to_str()
1203                .map_err(|_| axum::headers::Error::invalid())?
1204                .to_owned(),
1205        ))
1206    }
1207
1208    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1209        values.extend([self.0.parse().unwrap()]);
1210    }
1211}
1212
1213pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1214    Router::new()
1215        .route("/rpc", get(handle_websocket_request))
1216        .layer(
1217            ServiceBuilder::new()
1218                .layer(Extension(server.app_state.clone()))
1219                .layer(middleware::from_fn(auth::validate_header)),
1220        )
1221        .route("/metrics", get(handle_metrics))
1222        .layer(Extension(server))
1223}
1224
1225pub async fn handle_websocket_request(
1226    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1227    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1228    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1229    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1230    Extension(server): Extension<Arc<Server>>,
1231    Extension(principal): Extension<Principal>,
1232    user_agent: Option<TypedHeader<UserAgent>>,
1233    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1234    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1235    ws: WebSocketUpgrade,
1236) -> axum::response::Response {
1237    if protocol_version != rpc::PROTOCOL_VERSION {
1238        return (
1239            StatusCode::UPGRADE_REQUIRED,
1240            "client must be upgraded".to_string(),
1241        )
1242            .into_response();
1243    }
1244
1245    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1246        return (
1247            StatusCode::UPGRADE_REQUIRED,
1248            "no version header found".to_string(),
1249        )
1250            .into_response();
1251    };
1252
1253    let release_channel = release_channel_header.map(|header| header.0.0);
1254
1255    if !version.can_collaborate() {
1256        return (
1257            StatusCode::UPGRADE_REQUIRED,
1258            "client must be upgraded".to_string(),
1259        )
1260            .into_response();
1261    }
1262
1263    let socket_address = socket_address.to_string();
1264
1265    // Acquire connection guard before WebSocket upgrade
1266    let connection_guard = match ConnectionGuard::try_acquire() {
1267        Ok(guard) => guard,
1268        Err(()) => {
1269            return (
1270                StatusCode::SERVICE_UNAVAILABLE,
1271                "Too many concurrent connections",
1272            )
1273                .into_response();
1274        }
1275    };
1276
1277    ws.on_upgrade(move |socket| {
1278        let socket = socket
1279            .map_ok(to_tungstenite_message)
1280            .err_into()
1281            .with(|message| async move { to_axum_message(message) });
1282        let connection = Connection::new(Box::pin(socket));
1283        async move {
1284            server
1285                .handle_connection(
1286                    connection,
1287                    socket_address,
1288                    principal,
1289                    version,
1290                    release_channel,
1291                    user_agent.map(|header| header.to_string()),
1292                    country_code_header.map(|header| header.to_string()),
1293                    system_id_header.map(|header| header.to_string()),
1294                    None,
1295                    Executor::Production,
1296                    Some(connection_guard),
1297                )
1298                .await;
1299        }
1300    })
1301}
1302
1303pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1304    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1305    let connections_metric = CONNECTIONS_METRIC
1306        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1307
1308    let connections = server
1309        .connection_pool
1310        .lock()
1311        .connections()
1312        .filter(|connection| !connection.admin)
1313        .count();
1314    connections_metric.set(connections as _);
1315
1316    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1317    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1318        register_int_gauge!(
1319            "shared_projects",
1320            "number of open projects with one or more guests"
1321        )
1322        .unwrap()
1323    });
1324
1325    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1326    shared_projects_metric.set(shared_projects as _);
1327
1328    let encoder = prometheus::TextEncoder::new();
1329    let metric_families = prometheus::gather();
1330    let encoded_metrics = encoder
1331        .encode_to_string(&metric_families)
1332        .map_err(|err| anyhow!("{err}"))?;
1333    Ok(encoded_metrics)
1334}
1335
1336#[instrument(err, skip(executor))]
1337async fn connection_lost(
1338    session: Session,
1339    mut teardown: watch::Receiver<bool>,
1340    executor: Executor,
1341) -> Result<()> {
1342    session.peer.disconnect(session.connection_id);
1343    session
1344        .connection_pool()
1345        .await
1346        .remove_connection(session.connection_id)?;
1347
1348    session
1349        .db()
1350        .await
1351        .connection_lost(session.connection_id)
1352        .await
1353        .trace_err();
1354
1355    futures::select_biased! {
1356        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1357
1358            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1359            leave_room_for_session(&session, session.connection_id).await.trace_err();
1360            leave_channel_buffers_for_session(&session)
1361                .await
1362                .trace_err();
1363
1364            if !session
1365                .connection_pool()
1366                .await
1367                .is_user_online(session.user_id())
1368            {
1369                let db = session.db().await;
1370                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1371                    room_updated(&room, &session.peer);
1372                }
1373            }
1374
1375            update_user_contacts(session.user_id(), &session).await?;
1376        },
1377        _ = teardown.changed().fuse() => {}
1378    }
1379
1380    Ok(())
1381}
1382
1383/// Acknowledges a ping from a client, used to keep the connection alive.
1384async fn ping(
1385    _: proto::Ping,
1386    response: Response<proto::Ping>,
1387    _session: MessageContext,
1388) -> Result<()> {
1389    response.send(proto::Ack {})?;
1390    Ok(())
1391}
1392
1393/// Creates a new room for calling (outside of channels)
1394async fn create_room(
1395    _request: proto::CreateRoom,
1396    response: Response<proto::CreateRoom>,
1397    session: MessageContext,
1398) -> Result<()> {
1399    let livekit_room = nanoid::nanoid!(30);
1400
1401    let live_kit_connection_info = util::maybe!(async {
1402        let live_kit = session.app_state.livekit_client.as_ref();
1403        let live_kit = live_kit?;
1404        let user_id = session.user_id().to_string();
1405
1406        let token = live_kit
1407            .room_token(&livekit_room, &user_id.to_string())
1408            .trace_err()?;
1409
1410        Some(proto::LiveKitConnectionInfo {
1411            server_url: live_kit.url().into(),
1412            token,
1413            can_publish: true,
1414        })
1415    })
1416    .await;
1417
1418    let room = session
1419        .db()
1420        .await
1421        .create_room(session.user_id(), session.connection_id, &livekit_room)
1422        .await?;
1423
1424    response.send(proto::CreateRoomResponse {
1425        room: Some(room.clone()),
1426        live_kit_connection_info,
1427    })?;
1428
1429    update_user_contacts(session.user_id(), &session).await?;
1430    Ok(())
1431}
1432
1433/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1434async fn join_room(
1435    request: proto::JoinRoom,
1436    response: Response<proto::JoinRoom>,
1437    session: MessageContext,
1438) -> Result<()> {
1439    let room_id = RoomId::from_proto(request.id);
1440
1441    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1442
1443    if let Some(channel_id) = channel_id {
1444        return join_channel_internal(channel_id, Box::new(response), session).await;
1445    }
1446
1447    let joined_room = {
1448        let room = session
1449            .db()
1450            .await
1451            .join_room(room_id, session.user_id(), session.connection_id)
1452            .await?;
1453        room_updated(&room.room, &session.peer);
1454        room.into_inner()
1455    };
1456
1457    for connection_id in session
1458        .connection_pool()
1459        .await
1460        .user_connection_ids(session.user_id())
1461    {
1462        session
1463            .peer
1464            .send(
1465                connection_id,
1466                proto::CallCanceled {
1467                    room_id: room_id.to_proto(),
1468                },
1469            )
1470            .trace_err();
1471    }
1472
1473    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1474    {
1475        live_kit
1476            .room_token(
1477                &joined_room.room.livekit_room,
1478                &session.user_id().to_string(),
1479            )
1480            .trace_err()
1481            .map(|token| proto::LiveKitConnectionInfo {
1482                server_url: live_kit.url().into(),
1483                token,
1484                can_publish: true,
1485            })
1486    } else {
1487        None
1488    };
1489
1490    response.send(proto::JoinRoomResponse {
1491        room: Some(joined_room.room),
1492        channel_id: None,
1493        live_kit_connection_info,
1494    })?;
1495
1496    update_user_contacts(session.user_id(), &session).await?;
1497    Ok(())
1498}
1499
1500/// Rejoin room is used to reconnect to a room after connection errors.
1501async fn rejoin_room(
1502    request: proto::RejoinRoom,
1503    response: Response<proto::RejoinRoom>,
1504    session: MessageContext,
1505) -> Result<()> {
1506    let room;
1507    let channel;
1508    {
1509        let mut rejoined_room = session
1510            .db()
1511            .await
1512            .rejoin_room(request, session.user_id(), session.connection_id)
1513            .await?;
1514
1515        response.send(proto::RejoinRoomResponse {
1516            room: Some(rejoined_room.room.clone()),
1517            reshared_projects: rejoined_room
1518                .reshared_projects
1519                .iter()
1520                .map(|project| proto::ResharedProject {
1521                    id: project.id.to_proto(),
1522                    collaborators: project
1523                        .collaborators
1524                        .iter()
1525                        .map(|collaborator| collaborator.to_proto())
1526                        .collect(),
1527                })
1528                .collect(),
1529            rejoined_projects: rejoined_room
1530                .rejoined_projects
1531                .iter()
1532                .map(|rejoined_project| rejoined_project.to_proto())
1533                .collect(),
1534        })?;
1535        room_updated(&rejoined_room.room, &session.peer);
1536
1537        for project in &rejoined_room.reshared_projects {
1538            for collaborator in &project.collaborators {
1539                session
1540                    .peer
1541                    .send(
1542                        collaborator.connection_id,
1543                        proto::UpdateProjectCollaborator {
1544                            project_id: project.id.to_proto(),
1545                            old_peer_id: Some(project.old_connection_id.into()),
1546                            new_peer_id: Some(session.connection_id.into()),
1547                        },
1548                    )
1549                    .trace_err();
1550            }
1551
1552            broadcast(
1553                Some(session.connection_id),
1554                project
1555                    .collaborators
1556                    .iter()
1557                    .map(|collaborator| collaborator.connection_id),
1558                |connection_id| {
1559                    session.peer.forward_send(
1560                        session.connection_id,
1561                        connection_id,
1562                        proto::UpdateProject {
1563                            project_id: project.id.to_proto(),
1564                            worktrees: project.worktrees.clone(),
1565                        },
1566                    )
1567                },
1568            );
1569        }
1570
1571        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1572
1573        let rejoined_room = rejoined_room.into_inner();
1574
1575        room = rejoined_room.room;
1576        channel = rejoined_room.channel;
1577    }
1578
1579    if let Some(channel) = channel {
1580        channel_updated(
1581            &channel,
1582            &room,
1583            &session.peer,
1584            &*session.connection_pool().await,
1585        );
1586    }
1587
1588    update_user_contacts(session.user_id(), &session).await?;
1589    Ok(())
1590}
1591
1592fn notify_rejoined_projects(
1593    rejoined_projects: &mut Vec<RejoinedProject>,
1594    session: &Session,
1595) -> Result<()> {
1596    for project in rejoined_projects.iter() {
1597        for collaborator in &project.collaborators {
1598            session
1599                .peer
1600                .send(
1601                    collaborator.connection_id,
1602                    proto::UpdateProjectCollaborator {
1603                        project_id: project.id.to_proto(),
1604                        old_peer_id: Some(project.old_connection_id.into()),
1605                        new_peer_id: Some(session.connection_id.into()),
1606                    },
1607                )
1608                .trace_err();
1609        }
1610    }
1611
1612    for project in rejoined_projects {
1613        for worktree in mem::take(&mut project.worktrees) {
1614            // Stream this worktree's entries.
1615            let message = proto::UpdateWorktree {
1616                project_id: project.id.to_proto(),
1617                worktree_id: worktree.id,
1618                abs_path: worktree.abs_path.clone(),
1619                root_name: worktree.root_name,
1620                updated_entries: worktree.updated_entries,
1621                removed_entries: worktree.removed_entries,
1622                scan_id: worktree.scan_id,
1623                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1624                updated_repositories: worktree.updated_repositories,
1625                removed_repositories: worktree.removed_repositories,
1626            };
1627            for update in proto::split_worktree_update(message) {
1628                session.peer.send(session.connection_id, update)?;
1629            }
1630
1631            // Stream this worktree's diagnostics.
1632            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1633            if let Some(summary) = worktree_diagnostics.next() {
1634                let message = proto::UpdateDiagnosticSummary {
1635                    project_id: project.id.to_proto(),
1636                    worktree_id: worktree.id,
1637                    summary: Some(summary),
1638                    more_summaries: worktree_diagnostics.collect(),
1639                };
1640                session.peer.send(session.connection_id, message)?;
1641            }
1642
1643            for settings_file in worktree.settings_files {
1644                session.peer.send(
1645                    session.connection_id,
1646                    proto::UpdateWorktreeSettings {
1647                        project_id: project.id.to_proto(),
1648                        worktree_id: worktree.id,
1649                        path: settings_file.path,
1650                        content: Some(settings_file.content),
1651                        kind: Some(settings_file.kind.to_proto().into()),
1652                    },
1653                )?;
1654            }
1655        }
1656
1657        for repository in mem::take(&mut project.updated_repositories) {
1658            for update in split_repository_update(repository) {
1659                session.peer.send(session.connection_id, update)?;
1660            }
1661        }
1662
1663        for id in mem::take(&mut project.removed_repositories) {
1664            session.peer.send(
1665                session.connection_id,
1666                proto::RemoveRepository {
1667                    project_id: project.id.to_proto(),
1668                    id,
1669                },
1670            )?;
1671        }
1672    }
1673
1674    Ok(())
1675}
1676
1677/// leave room disconnects from the room.
1678async fn leave_room(
1679    _: proto::LeaveRoom,
1680    response: Response<proto::LeaveRoom>,
1681    session: MessageContext,
1682) -> Result<()> {
1683    leave_room_for_session(&session, session.connection_id).await?;
1684    response.send(proto::Ack {})?;
1685    Ok(())
1686}
1687
1688/// Updates the permissions of someone else in the room.
1689async fn set_room_participant_role(
1690    request: proto::SetRoomParticipantRole,
1691    response: Response<proto::SetRoomParticipantRole>,
1692    session: MessageContext,
1693) -> Result<()> {
1694    let user_id = UserId::from_proto(request.user_id);
1695    let role = ChannelRole::from(request.role());
1696
1697    let (livekit_room, can_publish) = {
1698        let room = session
1699            .db()
1700            .await
1701            .set_room_participant_role(
1702                session.user_id(),
1703                RoomId::from_proto(request.room_id),
1704                user_id,
1705                role,
1706            )
1707            .await?;
1708
1709        let livekit_room = room.livekit_room.clone();
1710        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1711        room_updated(&room, &session.peer);
1712        (livekit_room, can_publish)
1713    };
1714
1715    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1716        live_kit
1717            .update_participant(
1718                livekit_room.clone(),
1719                request.user_id.to_string(),
1720                livekit_api::proto::ParticipantPermission {
1721                    can_subscribe: true,
1722                    can_publish,
1723                    can_publish_data: can_publish,
1724                    hidden: false,
1725                    recorder: false,
1726                },
1727            )
1728            .await
1729            .trace_err();
1730    }
1731
1732    response.send(proto::Ack {})?;
1733    Ok(())
1734}
1735
1736/// Call someone else into the current room
1737async fn call(
1738    request: proto::Call,
1739    response: Response<proto::Call>,
1740    session: MessageContext,
1741) -> Result<()> {
1742    let room_id = RoomId::from_proto(request.room_id);
1743    let calling_user_id = session.user_id();
1744    let calling_connection_id = session.connection_id;
1745    let called_user_id = UserId::from_proto(request.called_user_id);
1746    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1747    if !session
1748        .db()
1749        .await
1750        .has_contact(calling_user_id, called_user_id)
1751        .await?
1752    {
1753        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1754    }
1755
1756    let incoming_call = {
1757        let (room, incoming_call) = &mut *session
1758            .db()
1759            .await
1760            .call(
1761                room_id,
1762                calling_user_id,
1763                calling_connection_id,
1764                called_user_id,
1765                initial_project_id,
1766            )
1767            .await?;
1768        room_updated(room, &session.peer);
1769        mem::take(incoming_call)
1770    };
1771    update_user_contacts(called_user_id, &session).await?;
1772
1773    let mut calls = session
1774        .connection_pool()
1775        .await
1776        .user_connection_ids(called_user_id)
1777        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1778        .collect::<FuturesUnordered<_>>();
1779
1780    while let Some(call_response) = calls.next().await {
1781        match call_response.as_ref() {
1782            Ok(_) => {
1783                response.send(proto::Ack {})?;
1784                return Ok(());
1785            }
1786            Err(_) => {
1787                call_response.trace_err();
1788            }
1789        }
1790    }
1791
1792    {
1793        let room = session
1794            .db()
1795            .await
1796            .call_failed(room_id, called_user_id)
1797            .await?;
1798        room_updated(&room, &session.peer);
1799    }
1800    update_user_contacts(called_user_id, &session).await?;
1801
1802    Err(anyhow!("failed to ring user"))?
1803}
1804
1805/// Cancel an outgoing call.
1806async fn cancel_call(
1807    request: proto::CancelCall,
1808    response: Response<proto::CancelCall>,
1809    session: MessageContext,
1810) -> Result<()> {
1811    let called_user_id = UserId::from_proto(request.called_user_id);
1812    let room_id = RoomId::from_proto(request.room_id);
1813    {
1814        let room = session
1815            .db()
1816            .await
1817            .cancel_call(room_id, session.connection_id, called_user_id)
1818            .await?;
1819        room_updated(&room, &session.peer);
1820    }
1821
1822    for connection_id in session
1823        .connection_pool()
1824        .await
1825        .user_connection_ids(called_user_id)
1826    {
1827        session
1828            .peer
1829            .send(
1830                connection_id,
1831                proto::CallCanceled {
1832                    room_id: room_id.to_proto(),
1833                },
1834            )
1835            .trace_err();
1836    }
1837    response.send(proto::Ack {})?;
1838
1839    update_user_contacts(called_user_id, &session).await?;
1840    Ok(())
1841}
1842
1843/// Decline an incoming call.
1844async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1845    let room_id = RoomId::from_proto(message.room_id);
1846    {
1847        let room = session
1848            .db()
1849            .await
1850            .decline_call(Some(room_id), session.user_id())
1851            .await?
1852            .context("declining call")?;
1853        room_updated(&room, &session.peer);
1854    }
1855
1856    for connection_id in session
1857        .connection_pool()
1858        .await
1859        .user_connection_ids(session.user_id())
1860    {
1861        session
1862            .peer
1863            .send(
1864                connection_id,
1865                proto::CallCanceled {
1866                    room_id: room_id.to_proto(),
1867                },
1868            )
1869            .trace_err();
1870    }
1871    update_user_contacts(session.user_id(), &session).await?;
1872    Ok(())
1873}
1874
1875/// Updates other participants in the room with your current location.
1876async fn update_participant_location(
1877    request: proto::UpdateParticipantLocation,
1878    response: Response<proto::UpdateParticipantLocation>,
1879    session: MessageContext,
1880) -> Result<()> {
1881    let room_id = RoomId::from_proto(request.room_id);
1882    let location = request.location.context("invalid location")?;
1883
1884    let db = session.db().await;
1885    let room = db
1886        .update_room_participant_location(room_id, session.connection_id, location)
1887        .await?;
1888
1889    room_updated(&room, &session.peer);
1890    response.send(proto::Ack {})?;
1891    Ok(())
1892}
1893
1894/// Share a project into the room.
1895async fn share_project(
1896    request: proto::ShareProject,
1897    response: Response<proto::ShareProject>,
1898    session: MessageContext,
1899) -> Result<()> {
1900    let (project_id, room) = &*session
1901        .db()
1902        .await
1903        .share_project(
1904            RoomId::from_proto(request.room_id),
1905            session.connection_id,
1906            &request.worktrees,
1907            request.is_ssh_project,
1908        )
1909        .await?;
1910    response.send(proto::ShareProjectResponse {
1911        project_id: project_id.to_proto(),
1912    })?;
1913    room_updated(room, &session.peer);
1914
1915    Ok(())
1916}
1917
1918/// Unshare a project from the room.
1919async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1920    let project_id = ProjectId::from_proto(message.project_id);
1921    unshare_project_internal(project_id, session.connection_id, &session).await
1922}
1923
1924async fn unshare_project_internal(
1925    project_id: ProjectId,
1926    connection_id: ConnectionId,
1927    session: &Session,
1928) -> Result<()> {
1929    let delete = {
1930        let room_guard = session
1931            .db()
1932            .await
1933            .unshare_project(project_id, connection_id)
1934            .await?;
1935
1936        let (delete, room, guest_connection_ids) = &*room_guard;
1937
1938        let message = proto::UnshareProject {
1939            project_id: project_id.to_proto(),
1940        };
1941
1942        broadcast(
1943            Some(connection_id),
1944            guest_connection_ids.iter().copied(),
1945            |conn_id| session.peer.send(conn_id, message.clone()),
1946        );
1947        if let Some(room) = room {
1948            room_updated(room, &session.peer);
1949        }
1950
1951        *delete
1952    };
1953
1954    if delete {
1955        let db = session.db().await;
1956        db.delete_project(project_id).await?;
1957    }
1958
1959    Ok(())
1960}
1961
1962/// Join someone elses shared project.
1963async fn join_project(
1964    request: proto::JoinProject,
1965    response: Response<proto::JoinProject>,
1966    session: MessageContext,
1967) -> Result<()> {
1968    let project_id = ProjectId::from_proto(request.project_id);
1969
1970    tracing::info!(%project_id, "join project");
1971
1972    let db = session.db().await;
1973    let (project, replica_id) = &mut *db
1974        .join_project(
1975            project_id,
1976            session.connection_id,
1977            session.user_id(),
1978            request.committer_name.clone(),
1979            request.committer_email.clone(),
1980        )
1981        .await?;
1982    drop(db);
1983    tracing::info!(%project_id, "join remote project");
1984    let collaborators = project
1985        .collaborators
1986        .iter()
1987        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1988        .map(|collaborator| collaborator.to_proto())
1989        .collect::<Vec<_>>();
1990    let project_id = project.id;
1991    let guest_user_id = session.user_id();
1992
1993    let worktrees = project
1994        .worktrees
1995        .iter()
1996        .map(|(id, worktree)| proto::WorktreeMetadata {
1997            id: *id,
1998            root_name: worktree.root_name.clone(),
1999            visible: worktree.visible,
2000            abs_path: worktree.abs_path.clone(),
2001        })
2002        .collect::<Vec<_>>();
2003
2004    let add_project_collaborator = proto::AddProjectCollaborator {
2005        project_id: project_id.to_proto(),
2006        collaborator: Some(proto::Collaborator {
2007            peer_id: Some(session.connection_id.into()),
2008            replica_id: replica_id.0 as u32,
2009            user_id: guest_user_id.to_proto(),
2010            is_host: false,
2011            committer_name: request.committer_name.clone(),
2012            committer_email: request.committer_email.clone(),
2013        }),
2014    };
2015
2016    for collaborator in &collaborators {
2017        session
2018            .peer
2019            .send(
2020                collaborator.peer_id.unwrap().into(),
2021                add_project_collaborator.clone(),
2022            )
2023            .trace_err();
2024    }
2025
2026    // First, we send the metadata associated with each worktree.
2027    let (language_servers, language_server_capabilities) = project
2028        .language_servers
2029        .clone()
2030        .into_iter()
2031        .map(|server| (server.server, server.capabilities))
2032        .unzip();
2033    response.send(proto::JoinProjectResponse {
2034        project_id: project.id.0 as u64,
2035        worktrees: worktrees.clone(),
2036        replica_id: replica_id.0 as u32,
2037        collaborators: collaborators.clone(),
2038        language_servers,
2039        language_server_capabilities,
2040        role: project.role.into(),
2041    })?;
2042
2043    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2044        // Stream this worktree's entries.
2045        let message = proto::UpdateWorktree {
2046            project_id: project_id.to_proto(),
2047            worktree_id,
2048            abs_path: worktree.abs_path.clone(),
2049            root_name: worktree.root_name,
2050            updated_entries: worktree.entries,
2051            removed_entries: Default::default(),
2052            scan_id: worktree.scan_id,
2053            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2054            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2055            removed_repositories: Default::default(),
2056        };
2057        for update in proto::split_worktree_update(message) {
2058            session.peer.send(session.connection_id, update.clone())?;
2059        }
2060
2061        // Stream this worktree's diagnostics.
2062        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
2063        if let Some(summary) = worktree_diagnostics.next() {
2064            let message = proto::UpdateDiagnosticSummary {
2065                project_id: project.id.to_proto(),
2066                worktree_id: worktree.id,
2067                summary: Some(summary),
2068                more_summaries: worktree_diagnostics.collect(),
2069            };
2070            session.peer.send(session.connection_id, message)?;
2071        }
2072
2073        for settings_file in worktree.settings_files {
2074            session.peer.send(
2075                session.connection_id,
2076                proto::UpdateWorktreeSettings {
2077                    project_id: project_id.to_proto(),
2078                    worktree_id: worktree.id,
2079                    path: settings_file.path,
2080                    content: Some(settings_file.content),
2081                    kind: Some(settings_file.kind.to_proto() as i32),
2082                },
2083            )?;
2084        }
2085    }
2086
2087    for repository in mem::take(&mut project.repositories) {
2088        for update in split_repository_update(repository) {
2089            session.peer.send(session.connection_id, update)?;
2090        }
2091    }
2092
2093    for language_server in &project.language_servers {
2094        session.peer.send(
2095            session.connection_id,
2096            proto::UpdateLanguageServer {
2097                project_id: project_id.to_proto(),
2098                server_name: Some(language_server.server.name.clone()),
2099                language_server_id: language_server.server.id,
2100                variant: Some(
2101                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2102                        proto::LspDiskBasedDiagnosticsUpdated {},
2103                    ),
2104                ),
2105            },
2106        )?;
2107    }
2108
2109    Ok(())
2110}
2111
2112/// Leave someone elses shared project.
2113async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2114    let sender_id = session.connection_id;
2115    let project_id = ProjectId::from_proto(request.project_id);
2116    let db = session.db().await;
2117
2118    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2119    tracing::info!(
2120        %project_id,
2121        "leave project"
2122    );
2123
2124    project_left(project, &session);
2125    if let Some(room) = room {
2126        room_updated(room, &session.peer);
2127    }
2128
2129    Ok(())
2130}
2131
2132/// Updates other participants with changes to the project
2133async fn update_project(
2134    request: proto::UpdateProject,
2135    response: Response<proto::UpdateProject>,
2136    session: MessageContext,
2137) -> Result<()> {
2138    let project_id = ProjectId::from_proto(request.project_id);
2139    let (room, guest_connection_ids) = &*session
2140        .db()
2141        .await
2142        .update_project(project_id, session.connection_id, &request.worktrees)
2143        .await?;
2144    broadcast(
2145        Some(session.connection_id),
2146        guest_connection_ids.iter().copied(),
2147        |connection_id| {
2148            session
2149                .peer
2150                .forward_send(session.connection_id, connection_id, request.clone())
2151        },
2152    );
2153    if let Some(room) = room {
2154        room_updated(room, &session.peer);
2155    }
2156    response.send(proto::Ack {})?;
2157
2158    Ok(())
2159}
2160
2161/// Updates other participants with changes to the worktree
2162async fn update_worktree(
2163    request: proto::UpdateWorktree,
2164    response: Response<proto::UpdateWorktree>,
2165    session: MessageContext,
2166) -> Result<()> {
2167    let guest_connection_ids = session
2168        .db()
2169        .await
2170        .update_worktree(&request, session.connection_id)
2171        .await?;
2172
2173    broadcast(
2174        Some(session.connection_id),
2175        guest_connection_ids.iter().copied(),
2176        |connection_id| {
2177            session
2178                .peer
2179                .forward_send(session.connection_id, connection_id, request.clone())
2180        },
2181    );
2182    response.send(proto::Ack {})?;
2183    Ok(())
2184}
2185
2186async fn update_repository(
2187    request: proto::UpdateRepository,
2188    response: Response<proto::UpdateRepository>,
2189    session: MessageContext,
2190) -> Result<()> {
2191    let guest_connection_ids = session
2192        .db()
2193        .await
2194        .update_repository(&request, session.connection_id)
2195        .await?;
2196
2197    broadcast(
2198        Some(session.connection_id),
2199        guest_connection_ids.iter().copied(),
2200        |connection_id| {
2201            session
2202                .peer
2203                .forward_send(session.connection_id, connection_id, request.clone())
2204        },
2205    );
2206    response.send(proto::Ack {})?;
2207    Ok(())
2208}
2209
2210async fn remove_repository(
2211    request: proto::RemoveRepository,
2212    response: Response<proto::RemoveRepository>,
2213    session: MessageContext,
2214) -> Result<()> {
2215    let guest_connection_ids = session
2216        .db()
2217        .await
2218        .remove_repository(&request, session.connection_id)
2219        .await?;
2220
2221    broadcast(
2222        Some(session.connection_id),
2223        guest_connection_ids.iter().copied(),
2224        |connection_id| {
2225            session
2226                .peer
2227                .forward_send(session.connection_id, connection_id, request.clone())
2228        },
2229    );
2230    response.send(proto::Ack {})?;
2231    Ok(())
2232}
2233
2234/// Updates other participants with changes to the diagnostics
2235async fn update_diagnostic_summary(
2236    message: proto::UpdateDiagnosticSummary,
2237    session: MessageContext,
2238) -> Result<()> {
2239    let guest_connection_ids = session
2240        .db()
2241        .await
2242        .update_diagnostic_summary(&message, session.connection_id)
2243        .await?;
2244
2245    broadcast(
2246        Some(session.connection_id),
2247        guest_connection_ids.iter().copied(),
2248        |connection_id| {
2249            session
2250                .peer
2251                .forward_send(session.connection_id, connection_id, message.clone())
2252        },
2253    );
2254
2255    Ok(())
2256}
2257
2258/// Updates other participants with changes to the worktree settings
2259async fn update_worktree_settings(
2260    message: proto::UpdateWorktreeSettings,
2261    session: MessageContext,
2262) -> Result<()> {
2263    let guest_connection_ids = session
2264        .db()
2265        .await
2266        .update_worktree_settings(&message, session.connection_id)
2267        .await?;
2268
2269    broadcast(
2270        Some(session.connection_id),
2271        guest_connection_ids.iter().copied(),
2272        |connection_id| {
2273            session
2274                .peer
2275                .forward_send(session.connection_id, connection_id, message.clone())
2276        },
2277    );
2278
2279    Ok(())
2280}
2281
2282/// Notify other participants that a language server has started.
2283async fn start_language_server(
2284    request: proto::StartLanguageServer,
2285    session: MessageContext,
2286) -> Result<()> {
2287    let guest_connection_ids = session
2288        .db()
2289        .await
2290        .start_language_server(&request, session.connection_id)
2291        .await?;
2292
2293    broadcast(
2294        Some(session.connection_id),
2295        guest_connection_ids.iter().copied(),
2296        |connection_id| {
2297            session
2298                .peer
2299                .forward_send(session.connection_id, connection_id, request.clone())
2300        },
2301    );
2302    Ok(())
2303}
2304
2305/// Notify other participants that a language server has changed.
2306async fn update_language_server(
2307    request: proto::UpdateLanguageServer,
2308    session: MessageContext,
2309) -> Result<()> {
2310    let project_id = ProjectId::from_proto(request.project_id);
2311    let db = session.db().await;
2312
2313    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2314    {
2315        if let Some(capabilities) = update.capabilities.clone() {
2316            db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2317                .await?;
2318        }
2319    }
2320
2321    let project_connection_ids = db
2322        .project_connection_ids(project_id, session.connection_id, true)
2323        .await?;
2324    broadcast(
2325        Some(session.connection_id),
2326        project_connection_ids.iter().copied(),
2327        |connection_id| {
2328            session
2329                .peer
2330                .forward_send(session.connection_id, connection_id, request.clone())
2331        },
2332    );
2333    Ok(())
2334}
2335
2336/// forward a project request to the host. These requests should be read only
2337/// as guests are allowed to send them.
2338async fn forward_read_only_project_request<T>(
2339    request: T,
2340    response: Response<T>,
2341    session: MessageContext,
2342) -> Result<()>
2343where
2344    T: EntityMessage + RequestMessage,
2345{
2346    let project_id = ProjectId::from_proto(request.remote_entity_id());
2347    let host_connection_id = session
2348        .db()
2349        .await
2350        .host_for_read_only_project_request(project_id, session.connection_id)
2351        .await?;
2352    let payload = session.forward_request(host_connection_id, request).await?;
2353    response.send(payload)?;
2354    Ok(())
2355}
2356
2357/// forward a project request to the host. These requests are disallowed
2358/// for guests.
2359async fn forward_mutating_project_request<T>(
2360    request: T,
2361    response: Response<T>,
2362    session: MessageContext,
2363) -> Result<()>
2364where
2365    T: EntityMessage + RequestMessage,
2366{
2367    let project_id = ProjectId::from_proto(request.remote_entity_id());
2368
2369    let host_connection_id = session
2370        .db()
2371        .await
2372        .host_for_mutating_project_request(project_id, session.connection_id)
2373        .await?;
2374    let payload = session.forward_request(host_connection_id, request).await?;
2375    response.send(payload)?;
2376    Ok(())
2377}
2378
2379async fn multi_lsp_query(
2380    request: MultiLspQuery,
2381    response: Response<MultiLspQuery>,
2382    session: MessageContext,
2383) -> Result<()> {
2384    tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2385    tracing::info!("multi_lsp_query message received");
2386    forward_mutating_project_request(request, response, session).await
2387}
2388
2389/// Notify other participants that a new buffer has been created
2390async fn create_buffer_for_peer(
2391    request: proto::CreateBufferForPeer,
2392    session: MessageContext,
2393) -> Result<()> {
2394    session
2395        .db()
2396        .await
2397        .check_user_is_project_host(
2398            ProjectId::from_proto(request.project_id),
2399            session.connection_id,
2400        )
2401        .await?;
2402    let peer_id = request.peer_id.context("invalid peer id")?;
2403    session
2404        .peer
2405        .forward_send(session.connection_id, peer_id.into(), request)?;
2406    Ok(())
2407}
2408
2409/// Notify other participants that a buffer has been updated. This is
2410/// allowed for guests as long as the update is limited to selections.
2411async fn update_buffer(
2412    request: proto::UpdateBuffer,
2413    response: Response<proto::UpdateBuffer>,
2414    session: MessageContext,
2415) -> Result<()> {
2416    let project_id = ProjectId::from_proto(request.project_id);
2417    let mut capability = Capability::ReadOnly;
2418
2419    for op in request.operations.iter() {
2420        match op.variant {
2421            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2422            Some(_) => capability = Capability::ReadWrite,
2423        }
2424    }
2425
2426    let host = {
2427        let guard = session
2428            .db()
2429            .await
2430            .connections_for_buffer_update(project_id, session.connection_id, capability)
2431            .await?;
2432
2433        let (host, guests) = &*guard;
2434
2435        broadcast(
2436            Some(session.connection_id),
2437            guests.clone(),
2438            |connection_id| {
2439                session
2440                    .peer
2441                    .forward_send(session.connection_id, connection_id, request.clone())
2442            },
2443        );
2444
2445        *host
2446    };
2447
2448    if host != session.connection_id {
2449        session.forward_request(host, request.clone()).await?;
2450    }
2451
2452    response.send(proto::Ack {})?;
2453    Ok(())
2454}
2455
2456async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2457    let project_id = ProjectId::from_proto(message.project_id);
2458
2459    let operation = message.operation.as_ref().context("invalid operation")?;
2460    let capability = match operation.variant.as_ref() {
2461        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2462            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2463                match buffer_op.variant {
2464                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2465                        Capability::ReadOnly
2466                    }
2467                    _ => Capability::ReadWrite,
2468                }
2469            } else {
2470                Capability::ReadWrite
2471            }
2472        }
2473        Some(_) => Capability::ReadWrite,
2474        None => Capability::ReadOnly,
2475    };
2476
2477    let guard = session
2478        .db()
2479        .await
2480        .connections_for_buffer_update(project_id, session.connection_id, capability)
2481        .await?;
2482
2483    let (host, guests) = &*guard;
2484
2485    broadcast(
2486        Some(session.connection_id),
2487        guests.iter().chain([host]).copied(),
2488        |connection_id| {
2489            session
2490                .peer
2491                .forward_send(session.connection_id, connection_id, message.clone())
2492        },
2493    );
2494
2495    Ok(())
2496}
2497
2498/// Notify other participants that a project has been updated.
2499async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2500    request: T,
2501    session: MessageContext,
2502) -> Result<()> {
2503    let project_id = ProjectId::from_proto(request.remote_entity_id());
2504    let project_connection_ids = session
2505        .db()
2506        .await
2507        .project_connection_ids(project_id, session.connection_id, false)
2508        .await?;
2509
2510    broadcast(
2511        Some(session.connection_id),
2512        project_connection_ids.iter().copied(),
2513        |connection_id| {
2514            session
2515                .peer
2516                .forward_send(session.connection_id, connection_id, request.clone())
2517        },
2518    );
2519    Ok(())
2520}
2521
2522/// Start following another user in a call.
2523async fn follow(
2524    request: proto::Follow,
2525    response: Response<proto::Follow>,
2526    session: MessageContext,
2527) -> Result<()> {
2528    let room_id = RoomId::from_proto(request.room_id);
2529    let project_id = request.project_id.map(ProjectId::from_proto);
2530    let leader_id = request.leader_id.context("invalid leader id")?.into();
2531    let follower_id = session.connection_id;
2532
2533    session
2534        .db()
2535        .await
2536        .check_room_participants(room_id, leader_id, session.connection_id)
2537        .await?;
2538
2539    let response_payload = session.forward_request(leader_id, request).await?;
2540    response.send(response_payload)?;
2541
2542    if let Some(project_id) = project_id {
2543        let room = session
2544            .db()
2545            .await
2546            .follow(room_id, project_id, leader_id, follower_id)
2547            .await?;
2548        room_updated(&room, &session.peer);
2549    }
2550
2551    Ok(())
2552}
2553
2554/// Stop following another user in a call.
2555async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2556    let room_id = RoomId::from_proto(request.room_id);
2557    let project_id = request.project_id.map(ProjectId::from_proto);
2558    let leader_id = request.leader_id.context("invalid leader id")?.into();
2559    let follower_id = session.connection_id;
2560
2561    session
2562        .db()
2563        .await
2564        .check_room_participants(room_id, leader_id, session.connection_id)
2565        .await?;
2566
2567    session
2568        .peer
2569        .forward_send(session.connection_id, leader_id, request)?;
2570
2571    if let Some(project_id) = project_id {
2572        let room = session
2573            .db()
2574            .await
2575            .unfollow(room_id, project_id, leader_id, follower_id)
2576            .await?;
2577        room_updated(&room, &session.peer);
2578    }
2579
2580    Ok(())
2581}
2582
2583/// Notify everyone following you of your current location.
2584async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2585    let room_id = RoomId::from_proto(request.room_id);
2586    let database = session.db.lock().await;
2587
2588    let connection_ids = if let Some(project_id) = request.project_id {
2589        let project_id = ProjectId::from_proto(project_id);
2590        database
2591            .project_connection_ids(project_id, session.connection_id, true)
2592            .await?
2593    } else {
2594        database
2595            .room_connection_ids(room_id, session.connection_id)
2596            .await?
2597    };
2598
2599    // For now, don't send view update messages back to that view's current leader.
2600    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2601        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2602        _ => None,
2603    });
2604
2605    for connection_id in connection_ids.iter().cloned() {
2606        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2607            session
2608                .peer
2609                .forward_send(session.connection_id, connection_id, request.clone())?;
2610        }
2611    }
2612    Ok(())
2613}
2614
2615/// Get public data about users.
2616async fn get_users(
2617    request: proto::GetUsers,
2618    response: Response<proto::GetUsers>,
2619    session: MessageContext,
2620) -> Result<()> {
2621    let user_ids = request
2622        .user_ids
2623        .into_iter()
2624        .map(UserId::from_proto)
2625        .collect();
2626    let users = session
2627        .db()
2628        .await
2629        .get_users_by_ids(user_ids)
2630        .await?
2631        .into_iter()
2632        .map(|user| proto::User {
2633            id: user.id.to_proto(),
2634            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2635            github_login: user.github_login,
2636            name: user.name,
2637        })
2638        .collect();
2639    response.send(proto::UsersResponse { users })?;
2640    Ok(())
2641}
2642
2643/// Search for users (to invite) buy Github login
2644async fn fuzzy_search_users(
2645    request: proto::FuzzySearchUsers,
2646    response: Response<proto::FuzzySearchUsers>,
2647    session: MessageContext,
2648) -> Result<()> {
2649    let query = request.query;
2650    let users = match query.len() {
2651        0 => vec![],
2652        1 | 2 => session
2653            .db()
2654            .await
2655            .get_user_by_github_login(&query)
2656            .await?
2657            .into_iter()
2658            .collect(),
2659        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2660    };
2661    let users = users
2662        .into_iter()
2663        .filter(|user| user.id != session.user_id())
2664        .map(|user| proto::User {
2665            id: user.id.to_proto(),
2666            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2667            github_login: user.github_login,
2668            name: user.name,
2669        })
2670        .collect();
2671    response.send(proto::UsersResponse { users })?;
2672    Ok(())
2673}
2674
2675/// Send a contact request to another user.
2676async fn request_contact(
2677    request: proto::RequestContact,
2678    response: Response<proto::RequestContact>,
2679    session: MessageContext,
2680) -> Result<()> {
2681    let requester_id = session.user_id();
2682    let responder_id = UserId::from_proto(request.responder_id);
2683    if requester_id == responder_id {
2684        return Err(anyhow!("cannot add yourself as a contact"))?;
2685    }
2686
2687    let notifications = session
2688        .db()
2689        .await
2690        .send_contact_request(requester_id, responder_id)
2691        .await?;
2692
2693    // Update outgoing contact requests of requester
2694    let mut update = proto::UpdateContacts::default();
2695    update.outgoing_requests.push(responder_id.to_proto());
2696    for connection_id in session
2697        .connection_pool()
2698        .await
2699        .user_connection_ids(requester_id)
2700    {
2701        session.peer.send(connection_id, update.clone())?;
2702    }
2703
2704    // Update incoming contact requests of responder
2705    let mut update = proto::UpdateContacts::default();
2706    update
2707        .incoming_requests
2708        .push(proto::IncomingContactRequest {
2709            requester_id: requester_id.to_proto(),
2710        });
2711    let connection_pool = session.connection_pool().await;
2712    for connection_id in connection_pool.user_connection_ids(responder_id) {
2713        session.peer.send(connection_id, update.clone())?;
2714    }
2715
2716    send_notifications(&connection_pool, &session.peer, notifications);
2717
2718    response.send(proto::Ack {})?;
2719    Ok(())
2720}
2721
2722/// Accept or decline a contact request
2723async fn respond_to_contact_request(
2724    request: proto::RespondToContactRequest,
2725    response: Response<proto::RespondToContactRequest>,
2726    session: MessageContext,
2727) -> Result<()> {
2728    let responder_id = session.user_id();
2729    let requester_id = UserId::from_proto(request.requester_id);
2730    let db = session.db().await;
2731    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2732        db.dismiss_contact_notification(responder_id, requester_id)
2733            .await?;
2734    } else {
2735        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2736
2737        let notifications = db
2738            .respond_to_contact_request(responder_id, requester_id, accept)
2739            .await?;
2740        let requester_busy = db.is_user_busy(requester_id).await?;
2741        let responder_busy = db.is_user_busy(responder_id).await?;
2742
2743        let pool = session.connection_pool().await;
2744        // Update responder with new contact
2745        let mut update = proto::UpdateContacts::default();
2746        if accept {
2747            update
2748                .contacts
2749                .push(contact_for_user(requester_id, requester_busy, &pool));
2750        }
2751        update
2752            .remove_incoming_requests
2753            .push(requester_id.to_proto());
2754        for connection_id in pool.user_connection_ids(responder_id) {
2755            session.peer.send(connection_id, update.clone())?;
2756        }
2757
2758        // Update requester with new contact
2759        let mut update = proto::UpdateContacts::default();
2760        if accept {
2761            update
2762                .contacts
2763                .push(contact_for_user(responder_id, responder_busy, &pool));
2764        }
2765        update
2766            .remove_outgoing_requests
2767            .push(responder_id.to_proto());
2768
2769        for connection_id in pool.user_connection_ids(requester_id) {
2770            session.peer.send(connection_id, update.clone())?;
2771        }
2772
2773        send_notifications(&pool, &session.peer, notifications);
2774    }
2775
2776    response.send(proto::Ack {})?;
2777    Ok(())
2778}
2779
2780/// Remove a contact.
2781async fn remove_contact(
2782    request: proto::RemoveContact,
2783    response: Response<proto::RemoveContact>,
2784    session: MessageContext,
2785) -> Result<()> {
2786    let requester_id = session.user_id();
2787    let responder_id = UserId::from_proto(request.user_id);
2788    let db = session.db().await;
2789    let (contact_accepted, deleted_notification_id) =
2790        db.remove_contact(requester_id, responder_id).await?;
2791
2792    let pool = session.connection_pool().await;
2793    // Update outgoing contact requests of requester
2794    let mut update = proto::UpdateContacts::default();
2795    if contact_accepted {
2796        update.remove_contacts.push(responder_id.to_proto());
2797    } else {
2798        update
2799            .remove_outgoing_requests
2800            .push(responder_id.to_proto());
2801    }
2802    for connection_id in pool.user_connection_ids(requester_id) {
2803        session.peer.send(connection_id, update.clone())?;
2804    }
2805
2806    // Update incoming contact requests of responder
2807    let mut update = proto::UpdateContacts::default();
2808    if contact_accepted {
2809        update.remove_contacts.push(requester_id.to_proto());
2810    } else {
2811        update
2812            .remove_incoming_requests
2813            .push(requester_id.to_proto());
2814    }
2815    for connection_id in pool.user_connection_ids(responder_id) {
2816        session.peer.send(connection_id, update.clone())?;
2817        if let Some(notification_id) = deleted_notification_id {
2818            session.peer.send(
2819                connection_id,
2820                proto::DeleteNotification {
2821                    notification_id: notification_id.to_proto(),
2822                },
2823            )?;
2824        }
2825    }
2826
2827    response.send(proto::Ack {})?;
2828    Ok(())
2829}
2830
2831fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2832    version.0.minor() < 139
2833}
2834
2835async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2836    if is_staff {
2837        return Ok(proto::Plan::ZedPro);
2838    }
2839
2840    let subscription = db.get_active_billing_subscription(user_id).await?;
2841    let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2842
2843    let plan = if let Some(subscription_kind) = subscription_kind {
2844        match subscription_kind {
2845            SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2846            SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2847            SubscriptionKind::ZedFree => proto::Plan::Free,
2848        }
2849    } else {
2850        proto::Plan::Free
2851    };
2852
2853    Ok(plan)
2854}
2855
2856async fn make_update_user_plan_message(
2857    user: &User,
2858    is_staff: bool,
2859    db: &Arc<Database>,
2860    llm_db: Option<Arc<LlmDatabase>>,
2861) -> Result<proto::UpdateUserPlan> {
2862    let feature_flags = db.get_user_flags(user.id).await?;
2863    let plan = current_plan(db, user.id, is_staff).await?;
2864    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
2865    let billing_preferences = db.get_billing_preferences(user.id).await?;
2866
2867    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
2868        let subscription = db.get_active_billing_subscription(user.id).await?;
2869
2870        let subscription_period =
2871            crate::db::billing_subscription::Model::current_period(subscription, is_staff);
2872
2873        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
2874            llm_db
2875                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
2876                .await?
2877        } else {
2878            None
2879        };
2880
2881        (subscription_period, usage)
2882    } else {
2883        (None, None)
2884    };
2885
2886    let bypass_account_age_check = feature_flags
2887        .iter()
2888        .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
2889    let account_too_young = !matches!(plan, proto::Plan::ZedPro)
2890        && !bypass_account_age_check
2891        && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;
2892
2893    Ok(proto::UpdateUserPlan {
2894        plan: plan.into(),
2895        trial_started_at: billing_customer
2896            .as_ref()
2897            .and_then(|billing_customer| billing_customer.trial_started_at)
2898            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2899        is_usage_based_billing_enabled: if is_staff {
2900            Some(true)
2901        } else {
2902            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
2903        },
2904        subscription_period: subscription_period.map(|(started_at, ended_at)| {
2905            proto::SubscriptionPeriod {
2906                started_at: started_at.timestamp() as u64,
2907                ended_at: ended_at.timestamp() as u64,
2908            }
2909        }),
2910        account_too_young: Some(account_too_young),
2911        has_overdue_invoices: billing_customer
2912            .map(|billing_customer| billing_customer.has_overdue_invoices),
2913        usage: Some(
2914            usage
2915                .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
2916                .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
2917        ),
2918    })
2919}
2920
2921fn model_requests_limit(
2922    plan: cloud_llm_client::Plan,
2923    feature_flags: &Vec<String>,
2924) -> cloud_llm_client::UsageLimit {
2925    match plan.model_requests_limit() {
2926        cloud_llm_client::UsageLimit::Limited(limit) => {
2927            let limit = if plan == cloud_llm_client::Plan::ZedProTrial
2928                && feature_flags
2929                    .iter()
2930                    .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2931            {
2932                1_000
2933            } else {
2934                limit
2935            };
2936
2937            cloud_llm_client::UsageLimit::Limited(limit)
2938        }
2939        cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited,
2940    }
2941}
2942
2943fn subscription_usage_to_proto(
2944    plan: proto::Plan,
2945    usage: crate::llm::db::subscription_usage::Model,
2946    feature_flags: &Vec<String>,
2947) -> proto::SubscriptionUsage {
2948    let plan = match plan {
2949        proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2950        proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
2951        proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
2952    };
2953
2954    proto::SubscriptionUsage {
2955        model_requests_usage_amount: usage.model_requests as u32,
2956        model_requests_usage_limit: Some(proto::UsageLimit {
2957            variant: Some(match model_requests_limit(plan, feature_flags) {
2958                cloud_llm_client::UsageLimit::Limited(limit) => {
2959                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2960                        limit: limit as u32,
2961                    })
2962                }
2963                cloud_llm_client::UsageLimit::Unlimited => {
2964                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2965                }
2966            }),
2967        }),
2968        edit_predictions_usage_amount: usage.edit_predictions as u32,
2969        edit_predictions_usage_limit: Some(proto::UsageLimit {
2970            variant: Some(match plan.edit_predictions_limit() {
2971                cloud_llm_client::UsageLimit::Limited(limit) => {
2972                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2973                        limit: limit as u32,
2974                    })
2975                }
2976                cloud_llm_client::UsageLimit::Unlimited => {
2977                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2978                }
2979            }),
2980        }),
2981    }
2982}
2983
2984fn make_default_subscription_usage(
2985    plan: proto::Plan,
2986    feature_flags: &Vec<String>,
2987) -> proto::SubscriptionUsage {
2988    let plan = match plan {
2989        proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2990        proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
2991        proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
2992    };
2993
2994    proto::SubscriptionUsage {
2995        model_requests_usage_amount: 0,
2996        model_requests_usage_limit: Some(proto::UsageLimit {
2997            variant: Some(match model_requests_limit(plan, feature_flags) {
2998                cloud_llm_client::UsageLimit::Limited(limit) => {
2999                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3000                        limit: limit as u32,
3001                    })
3002                }
3003                cloud_llm_client::UsageLimit::Unlimited => {
3004                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3005                }
3006            }),
3007        }),
3008        edit_predictions_usage_amount: 0,
3009        edit_predictions_usage_limit: Some(proto::UsageLimit {
3010            variant: Some(match plan.edit_predictions_limit() {
3011                cloud_llm_client::UsageLimit::Limited(limit) => {
3012                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3013                        limit: limit as u32,
3014                    })
3015                }
3016                cloud_llm_client::UsageLimit::Unlimited => {
3017                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3018                }
3019            }),
3020        }),
3021    }
3022}
3023
3024async fn update_user_plan(session: &Session) -> Result<()> {
3025    let db = session.db().await;
3026
3027    let update_user_plan = make_update_user_plan_message(
3028        session.principal.user(),
3029        session.is_staff(),
3030        &db.0,
3031        session.app_state.llm_db.clone(),
3032    )
3033    .await?;
3034
3035    session
3036        .peer
3037        .send(session.connection_id, update_user_plan)
3038        .trace_err();
3039
3040    Ok(())
3041}
3042
3043async fn subscribe_to_channels(
3044    _: proto::SubscribeToChannels,
3045    session: MessageContext,
3046) -> Result<()> {
3047    subscribe_user_to_channels(session.user_id(), &session).await?;
3048    Ok(())
3049}
3050
3051async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3052    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3053    let mut pool = session.connection_pool().await;
3054    for membership in &channels_for_user.channel_memberships {
3055        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3056    }
3057    session.peer.send(
3058        session.connection_id,
3059        build_update_user_channels(&channels_for_user),
3060    )?;
3061    session.peer.send(
3062        session.connection_id,
3063        build_channels_update(channels_for_user),
3064    )?;
3065    Ok(())
3066}
3067
3068/// Creates a new channel.
3069async fn create_channel(
3070    request: proto::CreateChannel,
3071    response: Response<proto::CreateChannel>,
3072    session: MessageContext,
3073) -> Result<()> {
3074    let db = session.db().await;
3075
3076    let parent_id = request.parent_id.map(ChannelId::from_proto);
3077    let (channel, membership) = db
3078        .create_channel(&request.name, parent_id, session.user_id())
3079        .await?;
3080
3081    let root_id = channel.root_id();
3082    let channel = Channel::from_model(channel);
3083
3084    response.send(proto::CreateChannelResponse {
3085        channel: Some(channel.to_proto()),
3086        parent_id: request.parent_id,
3087    })?;
3088
3089    let mut connection_pool = session.connection_pool().await;
3090    if let Some(membership) = membership {
3091        connection_pool.subscribe_to_channel(
3092            membership.user_id,
3093            membership.channel_id,
3094            membership.role,
3095        );
3096        let update = proto::UpdateUserChannels {
3097            channel_memberships: vec![proto::ChannelMembership {
3098                channel_id: membership.channel_id.to_proto(),
3099                role: membership.role.into(),
3100            }],
3101            ..Default::default()
3102        };
3103        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3104            session.peer.send(connection_id, update.clone())?;
3105        }
3106    }
3107
3108    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3109        if !role.can_see_channel(channel.visibility) {
3110            continue;
3111        }
3112
3113        let update = proto::UpdateChannels {
3114            channels: vec![channel.to_proto()],
3115            ..Default::default()
3116        };
3117        session.peer.send(connection_id, update.clone())?;
3118    }
3119
3120    Ok(())
3121}
3122
3123/// Delete a channel
3124async fn delete_channel(
3125    request: proto::DeleteChannel,
3126    response: Response<proto::DeleteChannel>,
3127    session: MessageContext,
3128) -> Result<()> {
3129    let db = session.db().await;
3130
3131    let channel_id = request.channel_id;
3132    let (root_channel, removed_channels) = db
3133        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3134        .await?;
3135    response.send(proto::Ack {})?;
3136
3137    // Notify members of removed channels
3138    let mut update = proto::UpdateChannels::default();
3139    update
3140        .delete_channels
3141        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3142
3143    let connection_pool = session.connection_pool().await;
3144    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3145        session.peer.send(connection_id, update.clone())?;
3146    }
3147
3148    Ok(())
3149}
3150
3151/// Invite someone to join a channel.
3152async fn invite_channel_member(
3153    request: proto::InviteChannelMember,
3154    response: Response<proto::InviteChannelMember>,
3155    session: MessageContext,
3156) -> Result<()> {
3157    let db = session.db().await;
3158    let channel_id = ChannelId::from_proto(request.channel_id);
3159    let invitee_id = UserId::from_proto(request.user_id);
3160    let InviteMemberResult {
3161        channel,
3162        notifications,
3163    } = db
3164        .invite_channel_member(
3165            channel_id,
3166            invitee_id,
3167            session.user_id(),
3168            request.role().into(),
3169        )
3170        .await?;
3171
3172    let update = proto::UpdateChannels {
3173        channel_invitations: vec![channel.to_proto()],
3174        ..Default::default()
3175    };
3176
3177    let connection_pool = session.connection_pool().await;
3178    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3179        session.peer.send(connection_id, update.clone())?;
3180    }
3181
3182    send_notifications(&connection_pool, &session.peer, notifications);
3183
3184    response.send(proto::Ack {})?;
3185    Ok(())
3186}
3187
3188/// remove someone from a channel
3189async fn remove_channel_member(
3190    request: proto::RemoveChannelMember,
3191    response: Response<proto::RemoveChannelMember>,
3192    session: MessageContext,
3193) -> Result<()> {
3194    let db = session.db().await;
3195    let channel_id = ChannelId::from_proto(request.channel_id);
3196    let member_id = UserId::from_proto(request.user_id);
3197
3198    let RemoveChannelMemberResult {
3199        membership_update,
3200        notification_id,
3201    } = db
3202        .remove_channel_member(channel_id, member_id, session.user_id())
3203        .await?;
3204
3205    let mut connection_pool = session.connection_pool().await;
3206    notify_membership_updated(
3207        &mut connection_pool,
3208        membership_update,
3209        member_id,
3210        &session.peer,
3211    );
3212    for connection_id in connection_pool.user_connection_ids(member_id) {
3213        if let Some(notification_id) = notification_id {
3214            session
3215                .peer
3216                .send(
3217                    connection_id,
3218                    proto::DeleteNotification {
3219                        notification_id: notification_id.to_proto(),
3220                    },
3221                )
3222                .trace_err();
3223        }
3224    }
3225
3226    response.send(proto::Ack {})?;
3227    Ok(())
3228}
3229
3230/// Toggle the channel between public and private.
3231/// Care is taken to maintain the invariant that public channels only descend from public channels,
3232/// (though members-only channels can appear at any point in the hierarchy).
3233async fn set_channel_visibility(
3234    request: proto::SetChannelVisibility,
3235    response: Response<proto::SetChannelVisibility>,
3236    session: MessageContext,
3237) -> Result<()> {
3238    let db = session.db().await;
3239    let channel_id = ChannelId::from_proto(request.channel_id);
3240    let visibility = request.visibility().into();
3241
3242    let channel_model = db
3243        .set_channel_visibility(channel_id, visibility, session.user_id())
3244        .await?;
3245    let root_id = channel_model.root_id();
3246    let channel = Channel::from_model(channel_model);
3247
3248    let mut connection_pool = session.connection_pool().await;
3249    for (user_id, role) in connection_pool
3250        .channel_user_ids(root_id)
3251        .collect::<Vec<_>>()
3252        .into_iter()
3253    {
3254        let update = if role.can_see_channel(channel.visibility) {
3255            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3256            proto::UpdateChannels {
3257                channels: vec![channel.to_proto()],
3258                ..Default::default()
3259            }
3260        } else {
3261            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3262            proto::UpdateChannels {
3263                delete_channels: vec![channel.id.to_proto()],
3264                ..Default::default()
3265            }
3266        };
3267
3268        for connection_id in connection_pool.user_connection_ids(user_id) {
3269            session.peer.send(connection_id, update.clone())?;
3270        }
3271    }
3272
3273    response.send(proto::Ack {})?;
3274    Ok(())
3275}
3276
3277/// Alter the role for a user in the channel.
3278async fn set_channel_member_role(
3279    request: proto::SetChannelMemberRole,
3280    response: Response<proto::SetChannelMemberRole>,
3281    session: MessageContext,
3282) -> Result<()> {
3283    let db = session.db().await;
3284    let channel_id = ChannelId::from_proto(request.channel_id);
3285    let member_id = UserId::from_proto(request.user_id);
3286    let result = db
3287        .set_channel_member_role(
3288            channel_id,
3289            session.user_id(),
3290            member_id,
3291            request.role().into(),
3292        )
3293        .await?;
3294
3295    match result {
3296        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3297            let mut connection_pool = session.connection_pool().await;
3298            notify_membership_updated(
3299                &mut connection_pool,
3300                membership_update,
3301                member_id,
3302                &session.peer,
3303            )
3304        }
3305        db::SetMemberRoleResult::InviteUpdated(channel) => {
3306            let update = proto::UpdateChannels {
3307                channel_invitations: vec![channel.to_proto()],
3308                ..Default::default()
3309            };
3310
3311            for connection_id in session
3312                .connection_pool()
3313                .await
3314                .user_connection_ids(member_id)
3315            {
3316                session.peer.send(connection_id, update.clone())?;
3317            }
3318        }
3319    }
3320
3321    response.send(proto::Ack {})?;
3322    Ok(())
3323}
3324
3325/// Change the name of a channel
3326async fn rename_channel(
3327    request: proto::RenameChannel,
3328    response: Response<proto::RenameChannel>,
3329    session: MessageContext,
3330) -> Result<()> {
3331    let db = session.db().await;
3332    let channel_id = ChannelId::from_proto(request.channel_id);
3333    let channel_model = db
3334        .rename_channel(channel_id, session.user_id(), &request.name)
3335        .await?;
3336    let root_id = channel_model.root_id();
3337    let channel = Channel::from_model(channel_model);
3338
3339    response.send(proto::RenameChannelResponse {
3340        channel: Some(channel.to_proto()),
3341    })?;
3342
3343    let connection_pool = session.connection_pool().await;
3344    let update = proto::UpdateChannels {
3345        channels: vec![channel.to_proto()],
3346        ..Default::default()
3347    };
3348    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3349        if role.can_see_channel(channel.visibility) {
3350            session.peer.send(connection_id, update.clone())?;
3351        }
3352    }
3353
3354    Ok(())
3355}
3356
3357/// Move a channel to a new parent.
3358async fn move_channel(
3359    request: proto::MoveChannel,
3360    response: Response<proto::MoveChannel>,
3361    session: MessageContext,
3362) -> Result<()> {
3363    let channel_id = ChannelId::from_proto(request.channel_id);
3364    let to = ChannelId::from_proto(request.to);
3365
3366    let (root_id, channels) = session
3367        .db()
3368        .await
3369        .move_channel(channel_id, to, session.user_id())
3370        .await?;
3371
3372    let connection_pool = session.connection_pool().await;
3373    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3374        let channels = channels
3375            .iter()
3376            .filter_map(|channel| {
3377                if role.can_see_channel(channel.visibility) {
3378                    Some(channel.to_proto())
3379                } else {
3380                    None
3381                }
3382            })
3383            .collect::<Vec<_>>();
3384        if channels.is_empty() {
3385            continue;
3386        }
3387
3388        let update = proto::UpdateChannels {
3389            channels,
3390            ..Default::default()
3391        };
3392
3393        session.peer.send(connection_id, update.clone())?;
3394    }
3395
3396    response.send(Ack {})?;
3397    Ok(())
3398}
3399
3400async fn reorder_channel(
3401    request: proto::ReorderChannel,
3402    response: Response<proto::ReorderChannel>,
3403    session: MessageContext,
3404) -> Result<()> {
3405    let channel_id = ChannelId::from_proto(request.channel_id);
3406    let direction = request.direction();
3407
3408    let updated_channels = session
3409        .db()
3410        .await
3411        .reorder_channel(channel_id, direction, session.user_id())
3412        .await?;
3413
3414    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3415        let connection_pool = session.connection_pool().await;
3416        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3417            let channels = updated_channels
3418                .iter()
3419                .filter_map(|channel| {
3420                    if role.can_see_channel(channel.visibility) {
3421                        Some(channel.to_proto())
3422                    } else {
3423                        None
3424                    }
3425                })
3426                .collect::<Vec<_>>();
3427
3428            if channels.is_empty() {
3429                continue;
3430            }
3431
3432            let update = proto::UpdateChannels {
3433                channels,
3434                ..Default::default()
3435            };
3436
3437            session.peer.send(connection_id, update.clone())?;
3438        }
3439    }
3440
3441    response.send(Ack {})?;
3442    Ok(())
3443}
3444
3445/// Get the list of channel members
3446async fn get_channel_members(
3447    request: proto::GetChannelMembers,
3448    response: Response<proto::GetChannelMembers>,
3449    session: MessageContext,
3450) -> Result<()> {
3451    let db = session.db().await;
3452    let channel_id = ChannelId::from_proto(request.channel_id);
3453    let limit = if request.limit == 0 {
3454        u16::MAX as u64
3455    } else {
3456        request.limit
3457    };
3458    let (members, users) = db
3459        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3460        .await?;
3461    response.send(proto::GetChannelMembersResponse { members, users })?;
3462    Ok(())
3463}
3464
3465/// Accept or decline a channel invitation.
3466async fn respond_to_channel_invite(
3467    request: proto::RespondToChannelInvite,
3468    response: Response<proto::RespondToChannelInvite>,
3469    session: MessageContext,
3470) -> Result<()> {
3471    let db = session.db().await;
3472    let channel_id = ChannelId::from_proto(request.channel_id);
3473    let RespondToChannelInvite {
3474        membership_update,
3475        notifications,
3476    } = db
3477        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3478        .await?;
3479
3480    let mut connection_pool = session.connection_pool().await;
3481    if let Some(membership_update) = membership_update {
3482        notify_membership_updated(
3483            &mut connection_pool,
3484            membership_update,
3485            session.user_id(),
3486            &session.peer,
3487        );
3488    } else {
3489        let update = proto::UpdateChannels {
3490            remove_channel_invitations: vec![channel_id.to_proto()],
3491            ..Default::default()
3492        };
3493
3494        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3495            session.peer.send(connection_id, update.clone())?;
3496        }
3497    };
3498
3499    send_notifications(&connection_pool, &session.peer, notifications);
3500
3501    response.send(proto::Ack {})?;
3502
3503    Ok(())
3504}
3505
3506/// Join the channels' room
3507async fn join_channel(
3508    request: proto::JoinChannel,
3509    response: Response<proto::JoinChannel>,
3510    session: MessageContext,
3511) -> Result<()> {
3512    let channel_id = ChannelId::from_proto(request.channel_id);
3513    join_channel_internal(channel_id, Box::new(response), session).await
3514}
3515
3516trait JoinChannelInternalResponse {
3517    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3518}
3519impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3520    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3521        Response::<proto::JoinChannel>::send(self, result)
3522    }
3523}
3524impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3525    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3526        Response::<proto::JoinRoom>::send(self, result)
3527    }
3528}
3529
/// Shared implementation for joining a channel's room.
///
/// Called by `join_channel`; because `JoinChannelInternalResponse` is also
/// implemented for `Response<proto::JoinRoom>`, it can answer `JoinRoom`
/// requests as well.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first.
/// 2. Join the channel's room in the database, which also yields an optional
///    membership update and the user's role in the channel.
/// 3. When a LiveKit client is configured, mint credentials: guests get a
///    guest token with `can_publish: false`, all other roles a room token with
///    `can_publish: true`. If token creation fails, the error goes through
///    `trace_err` and no connection info is sent (the join still succeeds).
/// 4. Reply to the caller, broadcast the membership update (if any) and the
///    room update, then the channel update, and finally refresh the user's
///    contacts.
///
/// # Errors
/// Propagates database errors, failure to send the response, a joined room
/// without a channel ("channel not returned"), and errors from
/// `leave_room_for_session` / `update_user_contacts`.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release our database handle while leaving the stale room, then
            // re-acquire it. NOTE(review): presumably required because
            // `leave_room_for_session` obtains the same handle itself — confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token creation failed.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests may join but not publish; every other role may publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the caller before broadcasting to anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // Leaving this scope drops both `db` and `connection_pool`.
        joined_room
    };

    // A joined room without a channel is reported as an error here, after the
    // caller has already received its response.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3623
3624/// Start editing the channel notes
3625async fn join_channel_buffer(
3626    request: proto::JoinChannelBuffer,
3627    response: Response<proto::JoinChannelBuffer>,
3628    session: MessageContext,
3629) -> Result<()> {
3630    let db = session.db().await;
3631    let channel_id = ChannelId::from_proto(request.channel_id);
3632
3633    let open_response = db
3634        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3635        .await?;
3636
3637    let collaborators = open_response.collaborators.clone();
3638    response.send(open_response)?;
3639
3640    let update = UpdateChannelBufferCollaborators {
3641        channel_id: channel_id.to_proto(),
3642        collaborators: collaborators.clone(),
3643    };
3644    channel_buffer_updated(
3645        session.connection_id,
3646        collaborators
3647            .iter()
3648            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3649        &update,
3650        &session.peer,
3651    );
3652
3653    Ok(())
3654}
3655
3656/// Edit the channel notes
3657async fn update_channel_buffer(
3658    request: proto::UpdateChannelBuffer,
3659    session: MessageContext,
3660) -> Result<()> {
3661    let db = session.db().await;
3662    let channel_id = ChannelId::from_proto(request.channel_id);
3663
3664    let (collaborators, epoch, version) = db
3665        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3666        .await?;
3667
3668    channel_buffer_updated(
3669        session.connection_id,
3670        collaborators.clone(),
3671        &proto::UpdateChannelBuffer {
3672            channel_id: channel_id.to_proto(),
3673            operations: request.operations,
3674        },
3675        &session.peer,
3676    );
3677
3678    let pool = &*session.connection_pool().await;
3679
3680    let non_collaborators =
3681        pool.channel_connection_ids(channel_id)
3682            .filter_map(|(connection_id, _)| {
3683                if collaborators.contains(&connection_id) {
3684                    None
3685                } else {
3686                    Some(connection_id)
3687                }
3688            });
3689
3690    broadcast(None, non_collaborators, |peer_id| {
3691        session.peer.send(
3692            peer_id,
3693            proto::UpdateChannels {
3694                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3695                    channel_id: channel_id.to_proto(),
3696                    epoch: epoch as u64,
3697                    version: version.clone(),
3698                }],
3699                ..Default::default()
3700            },
3701        )
3702    });
3703
3704    Ok(())
3705}
3706
3707/// Rejoin the channel notes after a connection blip
3708async fn rejoin_channel_buffers(
3709    request: proto::RejoinChannelBuffers,
3710    response: Response<proto::RejoinChannelBuffers>,
3711    session: MessageContext,
3712) -> Result<()> {
3713    let db = session.db().await;
3714    let buffers = db
3715        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3716        .await?;
3717
3718    for rejoined_buffer in &buffers {
3719        let collaborators_to_notify = rejoined_buffer
3720            .buffer
3721            .collaborators
3722            .iter()
3723            .filter_map(|c| Some(c.peer_id?.into()));
3724        channel_buffer_updated(
3725            session.connection_id,
3726            collaborators_to_notify,
3727            &proto::UpdateChannelBufferCollaborators {
3728                channel_id: rejoined_buffer.buffer.channel_id,
3729                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3730            },
3731            &session.peer,
3732        );
3733    }
3734
3735    response.send(proto::RejoinChannelBuffersResponse {
3736        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3737    })?;
3738
3739    Ok(())
3740}
3741
3742/// Stop editing the channel notes
3743async fn leave_channel_buffer(
3744    request: proto::LeaveChannelBuffer,
3745    response: Response<proto::LeaveChannelBuffer>,
3746    session: MessageContext,
3747) -> Result<()> {
3748    let db = session.db().await;
3749    let channel_id = ChannelId::from_proto(request.channel_id);
3750
3751    let left_buffer = db
3752        .leave_channel_buffer(channel_id, session.connection_id)
3753        .await?;
3754
3755    response.send(Ack {})?;
3756
3757    channel_buffer_updated(
3758        session.connection_id,
3759        left_buffer.connections,
3760        &proto::UpdateChannelBufferCollaborators {
3761            channel_id: channel_id.to_proto(),
3762            collaborators: left_buffer.collaborators,
3763        },
3764        &session.peer,
3765    );
3766
3767    Ok(())
3768}
3769
3770fn channel_buffer_updated<T: EnvelopedMessage>(
3771    sender_id: ConnectionId,
3772    collaborators: impl IntoIterator<Item = ConnectionId>,
3773    message: &T,
3774    peer: &Peer,
3775) {
3776    broadcast(Some(sender_id), collaborators, |peer_id| {
3777        peer.send(peer_id, message.clone())
3778    });
3779}
3780
3781fn send_notifications(
3782    connection_pool: &ConnectionPool,
3783    peer: &Peer,
3784    notifications: db::NotificationBatch,
3785) {
3786    for (user_id, notification) in notifications {
3787        for connection_id in connection_pool.user_connection_ids(user_id) {
3788            if let Err(error) = peer.send(
3789                connection_id,
3790                proto::AddNotification {
3791                    notification: Some(notification.clone()),
3792                },
3793            ) {
3794                tracing::error!(
3795                    "failed to send notification to {:?} {}",
3796                    connection_id,
3797                    error
3798                );
3799            }
3800        }
3801    }
3802}
3803
3804/// Send a message to the channel
3805async fn send_channel_message(
3806    request: proto::SendChannelMessage,
3807    response: Response<proto::SendChannelMessage>,
3808    session: MessageContext,
3809) -> Result<()> {
3810    // Validate the message body.
3811    let body = request.body.trim().to_string();
3812    if body.len() > MAX_MESSAGE_LEN {
3813        return Err(anyhow!("message is too long"))?;
3814    }
3815    if body.is_empty() {
3816        return Err(anyhow!("message can't be blank"))?;
3817    }
3818
3819    // TODO: adjust mentions if body is trimmed
3820
3821    let timestamp = OffsetDateTime::now_utc();
3822    let nonce = request.nonce.context("nonce can't be blank")?;
3823
3824    let channel_id = ChannelId::from_proto(request.channel_id);
3825    let CreatedChannelMessage {
3826        message_id,
3827        participant_connection_ids,
3828        notifications,
3829    } = session
3830        .db()
3831        .await
3832        .create_channel_message(
3833            channel_id,
3834            session.user_id(),
3835            &body,
3836            &request.mentions,
3837            timestamp,
3838            nonce.clone().into(),
3839            request.reply_to_message_id.map(MessageId::from_proto),
3840        )
3841        .await?;
3842
3843    let message = proto::ChannelMessage {
3844        sender_id: session.user_id().to_proto(),
3845        id: message_id.to_proto(),
3846        body,
3847        mentions: request.mentions,
3848        timestamp: timestamp.unix_timestamp() as u64,
3849        nonce: Some(nonce),
3850        reply_to_message_id: request.reply_to_message_id,
3851        edited_at: None,
3852    };
3853    broadcast(
3854        Some(session.connection_id),
3855        participant_connection_ids.clone(),
3856        |connection| {
3857            session.peer.send(
3858                connection,
3859                proto::ChannelMessageSent {
3860                    channel_id: channel_id.to_proto(),
3861                    message: Some(message.clone()),
3862                },
3863            )
3864        },
3865    );
3866    response.send(proto::SendChannelMessageResponse {
3867        message: Some(message),
3868    })?;
3869
3870    let pool = &*session.connection_pool().await;
3871    let non_participants =
3872        pool.channel_connection_ids(channel_id)
3873            .filter_map(|(connection_id, _)| {
3874                if participant_connection_ids.contains(&connection_id) {
3875                    None
3876                } else {
3877                    Some(connection_id)
3878                }
3879            });
3880    broadcast(None, non_participants, |peer_id| {
3881        session.peer.send(
3882            peer_id,
3883            proto::UpdateChannels {
3884                latest_channel_message_ids: vec![proto::ChannelMessageId {
3885                    channel_id: channel_id.to_proto(),
3886                    message_id: message_id.to_proto(),
3887                }],
3888                ..Default::default()
3889            },
3890        )
3891    });
3892    send_notifications(pool, &session.peer, notifications);
3893
3894    Ok(())
3895}
3896
3897/// Delete a channel message
3898async fn remove_channel_message(
3899    request: proto::RemoveChannelMessage,
3900    response: Response<proto::RemoveChannelMessage>,
3901    session: MessageContext,
3902) -> Result<()> {
3903    let channel_id = ChannelId::from_proto(request.channel_id);
3904    let message_id = MessageId::from_proto(request.message_id);
3905    let (connection_ids, existing_notification_ids) = session
3906        .db()
3907        .await
3908        .remove_channel_message(channel_id, message_id, session.user_id())
3909        .await?;
3910
3911    broadcast(
3912        Some(session.connection_id),
3913        connection_ids,
3914        move |connection| {
3915            session.peer.send(connection, request.clone())?;
3916
3917            for notification_id in &existing_notification_ids {
3918                session.peer.send(
3919                    connection,
3920                    proto::DeleteNotification {
3921                        notification_id: (*notification_id).to_proto(),
3922                    },
3923                )?;
3924            }
3925
3926            Ok(())
3927        },
3928    );
3929    response.send(proto::Ack {})?;
3930    Ok(())
3931}
3932
3933async fn update_channel_message(
3934    request: proto::UpdateChannelMessage,
3935    response: Response<proto::UpdateChannelMessage>,
3936    session: MessageContext,
3937) -> Result<()> {
3938    let channel_id = ChannelId::from_proto(request.channel_id);
3939    let message_id = MessageId::from_proto(request.message_id);
3940    let updated_at = OffsetDateTime::now_utc();
3941    let UpdatedChannelMessage {
3942        message_id,
3943        participant_connection_ids,
3944        notifications,
3945        reply_to_message_id,
3946        timestamp,
3947        deleted_mention_notification_ids,
3948        updated_mention_notifications,
3949    } = session
3950        .db()
3951        .await
3952        .update_channel_message(
3953            channel_id,
3954            message_id,
3955            session.user_id(),
3956            request.body.as_str(),
3957            &request.mentions,
3958            updated_at,
3959        )
3960        .await?;
3961
3962    let nonce = request.nonce.clone().context("nonce can't be blank")?;
3963
3964    let message = proto::ChannelMessage {
3965        sender_id: session.user_id().to_proto(),
3966        id: message_id.to_proto(),
3967        body: request.body.clone(),
3968        mentions: request.mentions.clone(),
3969        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3970        nonce: Some(nonce),
3971        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3972        edited_at: Some(updated_at.unix_timestamp() as u64),
3973    };
3974
3975    response.send(proto::Ack {})?;
3976
3977    let pool = &*session.connection_pool().await;
3978    broadcast(
3979        Some(session.connection_id),
3980        participant_connection_ids,
3981        |connection| {
3982            session.peer.send(
3983                connection,
3984                proto::ChannelMessageUpdate {
3985                    channel_id: channel_id.to_proto(),
3986                    message: Some(message.clone()),
3987                },
3988            )?;
3989
3990            for notification_id in &deleted_mention_notification_ids {
3991                session.peer.send(
3992                    connection,
3993                    proto::DeleteNotification {
3994                        notification_id: (*notification_id).to_proto(),
3995                    },
3996                )?;
3997            }
3998
3999            for notification in &updated_mention_notifications {
4000                session.peer.send(
4001                    connection,
4002                    proto::UpdateNotification {
4003                        notification: Some(notification.clone()),
4004                    },
4005                )?;
4006            }
4007
4008            Ok(())
4009        },
4010    );
4011
4012    send_notifications(pool, &session.peer, notifications);
4013
4014    Ok(())
4015}
4016
4017/// Mark a channel message as read
4018async fn acknowledge_channel_message(
4019    request: proto::AckChannelMessage,
4020    session: MessageContext,
4021) -> Result<()> {
4022    let channel_id = ChannelId::from_proto(request.channel_id);
4023    let message_id = MessageId::from_proto(request.message_id);
4024    let notifications = session
4025        .db()
4026        .await
4027        .observe_channel_message(channel_id, session.user_id(), message_id)
4028        .await?;
4029    send_notifications(
4030        &*session.connection_pool().await,
4031        &session.peer,
4032        notifications,
4033    );
4034    Ok(())
4035}
4036
4037/// Mark a buffer version as synced
4038async fn acknowledge_buffer_version(
4039    request: proto::AckBufferOperation,
4040    session: MessageContext,
4041) -> Result<()> {
4042    let buffer_id = BufferId::from_proto(request.buffer_id);
4043    session
4044        .db()
4045        .await
4046        .observe_buffer_version(
4047            buffer_id,
4048            session.user_id(),
4049            request.epoch as i32,
4050            &request.version,
4051        )
4052        .await?;
4053    Ok(())
4054}
4055
4056/// Get a Supermaven API key for the user
4057async fn get_supermaven_api_key(
4058    _request: proto::GetSupermavenApiKey,
4059    response: Response<proto::GetSupermavenApiKey>,
4060    session: MessageContext,
4061) -> Result<()> {
4062    let user_id: String = session.user_id().to_string();
4063    if !session.is_staff() {
4064        return Err(anyhow!("supermaven not enabled for this account"))?;
4065    }
4066
4067    let email = session.email().context("user must have an email")?;
4068
4069    let supermaven_admin_api = session
4070        .supermaven_client
4071        .as_ref()
4072        .context("supermaven not configured")?;
4073
4074    let result = supermaven_admin_api
4075        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4076        .await?;
4077
4078    response.send(proto::GetSupermavenApiKeyResponse {
4079        api_key: result.api_key,
4080    })?;
4081
4082    Ok(())
4083}
4084
4085/// Start receiving chat updates for a channel
4086async fn join_channel_chat(
4087    request: proto::JoinChannelChat,
4088    response: Response<proto::JoinChannelChat>,
4089    session: MessageContext,
4090) -> Result<()> {
4091    let channel_id = ChannelId::from_proto(request.channel_id);
4092
4093    let db = session.db().await;
4094    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4095        .await?;
4096    let messages = db
4097        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4098        .await?;
4099    response.send(proto::JoinChannelChatResponse {
4100        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4101        messages,
4102    })?;
4103    Ok(())
4104}
4105
4106/// Stop receiving chat updates for a channel
4107async fn leave_channel_chat(
4108    request: proto::LeaveChannelChat,
4109    session: MessageContext,
4110) -> Result<()> {
4111    let channel_id = ChannelId::from_proto(request.channel_id);
4112    session
4113        .db()
4114        .await
4115        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4116        .await?;
4117    Ok(())
4118}
4119
4120/// Retrieve the chat history for a channel
4121async fn get_channel_messages(
4122    request: proto::GetChannelMessages,
4123    response: Response<proto::GetChannelMessages>,
4124    session: MessageContext,
4125) -> Result<()> {
4126    let channel_id = ChannelId::from_proto(request.channel_id);
4127    let messages = session
4128        .db()
4129        .await
4130        .get_channel_messages(
4131            channel_id,
4132            session.user_id(),
4133            MESSAGE_COUNT_PER_PAGE,
4134            Some(MessageId::from_proto(request.before_message_id)),
4135        )
4136        .await?;
4137    response.send(proto::GetChannelMessagesResponse {
4138        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4139        messages,
4140    })?;
4141    Ok(())
4142}
4143
4144/// Retrieve specific chat messages
4145async fn get_channel_messages_by_id(
4146    request: proto::GetChannelMessagesById,
4147    response: Response<proto::GetChannelMessagesById>,
4148    session: MessageContext,
4149) -> Result<()> {
4150    let message_ids = request
4151        .message_ids
4152        .iter()
4153        .map(|id| MessageId::from_proto(*id))
4154        .collect::<Vec<_>>();
4155    let messages = session
4156        .db()
4157        .await
4158        .get_channel_messages_by_id(session.user_id(), &message_ids)
4159        .await?;
4160    response.send(proto::GetChannelMessagesResponse {
4161        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4162        messages,
4163    })?;
4164    Ok(())
4165}
4166
4167/// Retrieve the current users notifications
4168async fn get_notifications(
4169    request: proto::GetNotifications,
4170    response: Response<proto::GetNotifications>,
4171    session: MessageContext,
4172) -> Result<()> {
4173    let notifications = session
4174        .db()
4175        .await
4176        .get_notifications(
4177            session.user_id(),
4178            NOTIFICATION_COUNT_PER_PAGE,
4179            request.before_id.map(db::NotificationId::from_proto),
4180        )
4181        .await?;
4182    response.send(proto::GetNotificationsResponse {
4183        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4184        notifications,
4185    })?;
4186    Ok(())
4187}
4188
4189/// Mark notifications as read
4190async fn mark_notification_as_read(
4191    request: proto::MarkNotificationRead,
4192    response: Response<proto::MarkNotificationRead>,
4193    session: MessageContext,
4194) -> Result<()> {
4195    let database = &session.db().await;
4196    let notifications = database
4197        .mark_notification_as_read_by_id(
4198            session.user_id(),
4199            NotificationId::from_proto(request.notification_id),
4200        )
4201        .await?;
4202    send_notifications(
4203        &*session.connection_pool().await,
4204        &session.peer,
4205        notifications,
4206    );
4207    response.send(proto::Ack {})?;
4208    Ok(())
4209}
4210
4211/// Accept the terms of service (tos) on behalf of the current user
4212async fn accept_terms_of_service(
4213    _request: proto::AcceptTermsOfService,
4214    response: Response<proto::AcceptTermsOfService>,
4215    session: MessageContext,
4216) -> Result<()> {
4217    let db = session.db().await;
4218
4219    let accepted_tos_at = Utc::now();
4220    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4221        .await?;
4222
4223    response.send(proto::AcceptTermsOfServiceResponse {
4224        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4225    })?;
4226
4227    Ok(())
4228}
4229
4230fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4231    let message = match message {
4232        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4233        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4234        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4235        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4236        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4237            code: frame.code.into(),
4238            reason: frame.reason.as_str().to_owned().into(),
4239        })),
4240        // We should never receive a frame while reading the message, according
4241        // to the `tungstenite` maintainers:
4242        //
4243        // > It cannot occur when you read messages from the WebSocket, but it
4244        // > can be used when you want to send the raw frames (e.g. you want to
4245        // > send the frames to the WebSocket without composing the full message first).
4246        // >
4247        // > — https://github.com/snapview/tungstenite-rs/issues/268
4248        TungsteniteMessage::Frame(_) => {
4249            bail!("received an unexpected frame while reading the message")
4250        }
4251    };
4252
4253    Ok(message)
4254}
4255
4256fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4257    match message {
4258        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4259        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4260        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4261        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4262        AxumMessage::Close(frame) => {
4263            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4264                code: frame.code.into(),
4265                reason: frame.reason.as_ref().into(),
4266            }))
4267        }
4268    }
4269}
4270
4271fn notify_membership_updated(
4272    connection_pool: &mut ConnectionPool,
4273    result: MembershipUpdated,
4274    user_id: UserId,
4275    peer: &Peer,
4276) {
4277    for membership in &result.new_channels.channel_memberships {
4278        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4279    }
4280    for channel_id in &result.removed_channels {
4281        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4282    }
4283
4284    let user_channels_update = proto::UpdateUserChannels {
4285        channel_memberships: result
4286            .new_channels
4287            .channel_memberships
4288            .iter()
4289            .map(|cm| proto::ChannelMembership {
4290                channel_id: cm.channel_id.to_proto(),
4291                role: cm.role.into(),
4292            })
4293            .collect(),
4294        ..Default::default()
4295    };
4296
4297    let mut update = build_channels_update(result.new_channels);
4298    update.delete_channels = result
4299        .removed_channels
4300        .into_iter()
4301        .map(|id| id.to_proto())
4302        .collect();
4303    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4304
4305    for connection_id in connection_pool.user_connection_ids(user_id) {
4306        peer.send(connection_id, user_channels_update.clone())
4307            .trace_err();
4308        peer.send(connection_id, update.clone()).trace_err();
4309    }
4310}
4311
4312fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4313    proto::UpdateUserChannels {
4314        channel_memberships: channels
4315            .channel_memberships
4316            .iter()
4317            .map(|m| proto::ChannelMembership {
4318                channel_id: m.channel_id.to_proto(),
4319                role: m.role.into(),
4320            })
4321            .collect(),
4322        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4323        observed_channel_message_id: channels.observed_channel_messages.clone(),
4324    }
4325}
4326
4327fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4328    let mut update = proto::UpdateChannels::default();
4329
4330    for channel in channels.channels {
4331        update.channels.push(channel.to_proto());
4332    }
4333
4334    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4335    update.latest_channel_message_ids = channels.latest_channel_messages;
4336
4337    for (channel_id, participants) in channels.channel_participants {
4338        update
4339            .channel_participants
4340            .push(proto::ChannelParticipants {
4341                channel_id: channel_id.to_proto(),
4342                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4343            });
4344    }
4345
4346    for channel in channels.invited_channels {
4347        update.channel_invitations.push(channel.to_proto());
4348    }
4349
4350    update
4351}
4352
4353fn build_initial_contacts_update(
4354    contacts: Vec<db::Contact>,
4355    pool: &ConnectionPool,
4356) -> proto::UpdateContacts {
4357    let mut update = proto::UpdateContacts::default();
4358
4359    for contact in contacts {
4360        match contact {
4361            db::Contact::Accepted { user_id, busy } => {
4362                update.contacts.push(contact_for_user(user_id, busy, pool));
4363            }
4364            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4365            db::Contact::Incoming { user_id } => {
4366                update
4367                    .incoming_requests
4368                    .push(proto::IncomingContactRequest {
4369                        requester_id: user_id.to_proto(),
4370                    })
4371            }
4372        }
4373    }
4374
4375    update
4376}
4377
4378fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4379    proto::Contact {
4380        user_id: user_id.to_proto(),
4381        online: pool.is_user_online(user_id),
4382        busy,
4383    }
4384}
4385
4386fn room_updated(room: &proto::Room, peer: &Peer) {
4387    broadcast(
4388        None,
4389        room.participants
4390            .iter()
4391            .filter_map(|participant| Some(participant.peer_id?.into())),
4392        |peer_id| {
4393            peer.send(
4394                peer_id,
4395                proto::RoomUpdated {
4396                    room: Some(room.clone()),
4397                },
4398            )
4399        },
4400    );
4401}
4402
4403fn channel_updated(
4404    channel: &db::channel::Model,
4405    room: &proto::Room,
4406    peer: &Peer,
4407    pool: &ConnectionPool,
4408) {
4409    let participants = room
4410        .participants
4411        .iter()
4412        .map(|p| p.user_id)
4413        .collect::<Vec<_>>();
4414
4415    broadcast(
4416        None,
4417        pool.channel_connection_ids(channel.root_id())
4418            .filter_map(|(channel_id, role)| {
4419                role.can_see_channel(channel.visibility)
4420                    .then_some(channel_id)
4421            }),
4422        |peer_id| {
4423            peer.send(
4424                peer_id,
4425                proto::UpdateChannels {
4426                    channel_participants: vec![proto::ChannelParticipants {
4427                        channel_id: channel.id.to_proto(),
4428                        participant_user_ids: participants.clone(),
4429                    }],
4430                    ..Default::default()
4431                },
4432            )
4433        },
4434    );
4435}
4436
4437async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4438    let db = session.db().await;
4439
4440    let contacts = db.get_contacts(user_id).await?;
4441    let busy = db.is_user_busy(user_id).await?;
4442
4443    let pool = session.connection_pool().await;
4444    let updated_contact = contact_for_user(user_id, busy, &pool);
4445    for contact in contacts {
4446        if let db::Contact::Accepted {
4447            user_id: contact_user_id,
4448            ..
4449        } = contact
4450        {
4451            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4452                session
4453                    .peer
4454                    .send(
4455                        contact_conn_id,
4456                        proto::UpdateContacts {
4457                            contacts: vec![updated_contact.clone()],
4458                            remove_contacts: Default::default(),
4459                            incoming_requests: Default::default(),
4460                            remove_incoming_requests: Default::default(),
4461                            outgoing_requests: Default::default(),
4462                            remove_outgoing_requests: Default::default(),
4463                        },
4464                    )
4465                    .trace_err();
4466            }
4467        }
4468    }
4469    Ok(())
4470}
4471
/// Removes `connection_id` from the room it is in and propagates the consequences.
///
/// If the database reports that the connection left a room, this:
/// - notifies every project the connection left (see `project_left`);
/// - broadcasts the updated room to its participants (`room_updated`);
/// - broadcasts the new participant list to the room's channel, if any
///   (`channel_updated`);
/// - sends `CallCanceled` to every connection of each user in the room's
///   `canceled_calls_to_user_ids`;
/// - refreshes the contact entries of the session's user and of those users;
/// - removes the session's user from the LiveKit room, and deletes the LiveKit
///   room when the database flagged the room as `deleted`.
///
/// Returns `Ok(())` immediately if `leave_room` reports no room was left.
/// Errors from `leave_room` and `update_user_contacts` are propagated; message
/// sends and LiveKit calls are best-effort (logged via `trace_err`).
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned inside the `if let` below. The values are moved out of
    // `left_room` so they remain usable after the `if let` ends. The temporary
    // database guard from `session.db().await` in its scrutinee lives for the
    // body of the `if let`.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out. `mem::take` leaves the default value
        // behind, so the `room` broadcast below no longer carries the LiveKit
        // room name.
        // NOTE(review): confirm clients do not rely on `livekit_room` in this update.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is dropped before the loop below;
    // `update_user_contacts` acquires the connection pool again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup only runs when a LiveKit client is configured; failures
    // are logged, not returned.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4545
4546async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4547    let left_channel_buffers = session
4548        .db()
4549        .await
4550        .leave_channel_buffers(session.connection_id)
4551        .await?;
4552
4553    for left_buffer in left_channel_buffers {
4554        channel_buffer_updated(
4555            session.connection_id,
4556            left_buffer.connections,
4557            &proto::UpdateChannelBufferCollaborators {
4558                channel_id: left_buffer.channel_id.to_proto(),
4559                collaborators: left_buffer.collaborators,
4560            },
4561            &session.peer,
4562        );
4563    }
4564
4565    Ok(())
4566}
4567
4568fn project_left(project: &db::LeftProject, session: &Session) {
4569    for connection_id in &project.connection_ids {
4570        if project.should_unshare {
4571            session
4572                .peer
4573                .send(
4574                    *connection_id,
4575                    proto::UnshareProject {
4576                        project_id: project.id.to_proto(),
4577                    },
4578                )
4579                .trace_err();
4580        } else {
4581            session
4582                .peer
4583                .send(
4584                    *connection_id,
4585                    proto::RemoveProjectCollaborator {
4586                        project_id: project.id.to_proto(),
4587                        peer_id: Some(session.connection_id.into()),
4588                    },
4589                )
4590                .trace_err();
4591        }
4592    }
4593}
4594
4595pub trait ResultExt {
4596    type Ok;
4597
4598    fn trace_err(self) -> Option<Self::Ok>;
4599}
4600
4601impl<T, E> ResultExt for Result<T, E>
4602where
4603    E: std::fmt::Debug,
4604{
4605    type Ok = T;
4606
4607    #[track_caller]
4608    fn trace_err(self) -> Option<T> {
4609        match self {
4610            Ok(value) => Some(value),
4611            Err(error) => {
4612                tracing::error!("{:?}", error);
4613                None
4614            }
4615        }
4616    }
4617}