rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::{
   5    AppState, Error, Result, auth,
   6    db::{
   7        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
   8        InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
   9        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
  10    },
  11    executor::Executor,
  12};
  13use anyhow::{Context as _, anyhow, bail};
  14use async_tungstenite::tungstenite::{
  15    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  16};
  17use axum::headers::UserAgent;
  18use axum::{
  19    Extension, Router, TypedHeader,
  20    body::Body,
  21    extract::{
  22        ConnectInfo, WebSocketUpgrade,
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30};
  31use collections::{HashMap, HashSet};
  32pub use connection_pool::{ConnectionPool, ZedVersion};
  33use core::fmt::{self, Debug, Formatter};
  34use futures::TryFutureExt as _;
  35use reqwest_client::ReqwestClient;
  36use rpc::proto::split_repository_update;
  37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  38use tracing::Span;
  39use util::paths::PathStyle;
  40
  41use futures::{
  42    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  43    stream::FuturesUnordered,
  44};
  45use prometheus::{IntGauge, register_int_gauge};
  46use rpc::{
  47    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  50        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52};
  53use semver::Version;
  54use serde::{Serialize, Serializer};
  55use std::{
  56    any::TypeId,
  57    future::Future,
  58    marker::PhantomData,
  59    mem,
  60    net::SocketAddr,
  61    ops::{Deref, DerefMut},
  62    rc::Rc,
  63    sync::{
  64        Arc, OnceLock,
  65        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  66    },
  67    time::{Duration, Instant},
  68};
  69use tokio::sync::{Semaphore, watch};
  70use tower::ServiceBuilder;
  71use tracing::{
  72    Instrument,
  73    field::{self},
  74    info_span, instrument,
  75};
  76
  77pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  78
  79// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  80pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  81
  82const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  83const MAX_CONCURRENT_CONNECTIONS: usize = 512;
  84
  85static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
  86
  87const TOTAL_DURATION_MS: &str = "total_duration_ms";
  88const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
  89const QUEUE_DURATION_MS: &str = "queue_duration_ms";
  90const HOST_WAITING_MS: &str = "host_waiting_ms";
  91
  92type MessageHandler =
  93    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
  94
  95pub struct ConnectionGuard;
  96
  97impl ConnectionGuard {
  98    pub fn try_acquire() -> Result<Self, ()> {
  99        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
 100        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 101            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 102            tracing::error!(
 103                "too many concurrent connections: {}",
 104                current_connections + 1
 105            );
 106            return Err(());
 107        }
 108        Ok(ConnectionGuard)
 109    }
 110}
 111
 112impl Drop for ConnectionGuard {
 113    fn drop(&mut self) {
 114        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 115    }
 116}
 117
 118struct Response<R> {
 119    peer: Arc<Peer>,
 120    receipt: Receipt<R>,
 121    responded: Arc<AtomicBool>,
 122}
 123
 124impl<R: RequestMessage> Response<R> {
 125    fn send(self, payload: R::Response) -> Result<()> {
 126        self.responded.store(true, SeqCst);
 127        self.peer.respond(self.receipt, payload)?;
 128        Ok(())
 129    }
 130}
 131
 132#[derive(Clone, Debug)]
 133pub enum Principal {
 134    User(User),
 135    Impersonated { user: User, admin: User },
 136}
 137
 138impl Principal {
 139    fn update_span(&self, span: &tracing::Span) {
 140        match &self {
 141            Principal::User(user) => {
 142                span.record("user_id", user.id.0);
 143                span.record("login", &user.github_login);
 144            }
 145            Principal::Impersonated { user, admin } => {
 146                span.record("user_id", user.id.0);
 147                span.record("login", &user.github_login);
 148                span.record("impersonator", &admin.github_login);
 149            }
 150        }
 151    }
 152}
 153
 154#[derive(Clone)]
 155struct MessageContext {
 156    session: Session,
 157    span: tracing::Span,
 158}
 159
 160impl Deref for MessageContext {
 161    type Target = Session;
 162
 163    fn deref(&self) -> &Self::Target {
 164        &self.session
 165    }
 166}
 167
 168impl MessageContext {
 169    pub fn forward_request<T: RequestMessage>(
 170        &self,
 171        receiver_id: ConnectionId,
 172        request: T,
 173    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 174        let request_start_time = Instant::now();
 175        let span = self.span.clone();
 176        tracing::info!("start forwarding request");
 177        self.peer
 178            .forward_request(self.connection_id, receiver_id, request)
 179            .inspect(move |_| {
 180                span.record(
 181                    HOST_WAITING_MS,
 182                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 183                );
 184            })
 185            .inspect_err(|_| tracing::error!("error forwarding request"))
 186            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 187    }
 188}
 189
 190#[derive(Clone)]
 191struct Session {
 192    principal: Principal,
 193    connection_id: ConnectionId,
 194    db: Arc<tokio::sync::Mutex<DbHandle>>,
 195    peer: Arc<Peer>,
 196    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 197    app_state: Arc<AppState>,
 198    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 199    /// The GeoIP country code for the user.
 200    #[allow(unused)]
 201    geoip_country_code: Option<String>,
 202    #[allow(unused)]
 203    system_id: Option<String>,
 204    _executor: Executor,
 205}
 206
 207impl Session {
 208    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 209        #[cfg(test)]
 210        tokio::task::yield_now().await;
 211        let guard = self.db.lock().await;
 212        #[cfg(test)]
 213        tokio::task::yield_now().await;
 214        guard
 215    }
 216
 217    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 218        #[cfg(test)]
 219        tokio::task::yield_now().await;
 220        let guard = self.connection_pool.lock();
 221        ConnectionPoolGuard {
 222            guard,
 223            _not_send: PhantomData,
 224        }
 225    }
 226
 227    fn is_staff(&self) -> bool {
 228        match &self.principal {
 229            Principal::User(user) => user.admin,
 230            Principal::Impersonated { .. } => true,
 231        }
 232    }
 233
 234    fn user_id(&self) -> UserId {
 235        match &self.principal {
 236            Principal::User(user) => user.id,
 237            Principal::Impersonated { user, .. } => user.id,
 238        }
 239    }
 240
 241    pub fn email(&self) -> Option<String> {
 242        match &self.principal {
 243            Principal::User(user) => user.email_address.clone(),
 244            Principal::Impersonated { user, .. } => user.email_address.clone(),
 245        }
 246    }
 247}
 248
 249impl Debug for Session {
 250    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 251        let mut result = f.debug_struct("Session");
 252        match &self.principal {
 253            Principal::User(user) => {
 254                result.field("user", &user.github_login);
 255            }
 256            Principal::Impersonated { user, admin } => {
 257                result.field("user", &user.github_login);
 258                result.field("impersonator", &admin.github_login);
 259            }
 260        }
 261        result.field("connection_id", &self.connection_id).finish()
 262    }
 263}
 264
 265struct DbHandle(Arc<Database>);
 266
 267impl Deref for DbHandle {
 268    type Target = Database;
 269
 270    fn deref(&self) -> &Self::Target {
 271        self.0.as_ref()
 272    }
 273}
 274
 275pub struct Server {
 276    id: parking_lot::Mutex<ServerId>,
 277    peer: Arc<Peer>,
 278    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 279    app_state: Arc<AppState>,
 280    handlers: HashMap<TypeId, MessageHandler>,
 281    teardown: watch::Sender<bool>,
 282}
 283
 284pub(crate) struct ConnectionPoolGuard<'a> {
 285    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 286    _not_send: PhantomData<Rc<()>>,
 287}
 288
 289#[derive(Serialize)]
 290pub struct ServerSnapshot<'a> {
 291    peer: &'a Peer,
 292    #[serde(serialize_with = "serialize_deref")]
 293    connection_pool: ConnectionPoolGuard<'a>,
 294}
 295
 296pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 297where
 298    S: Serializer,
 299    T: Deref<Target = U>,
 300    U: Serialize,
 301{
 302    Serialize::serialize(value.deref(), serializer)
 303}
 304
 305impl Server {
 306    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 307        let mut server = Self {
 308            id: parking_lot::Mutex::new(id),
 309            peer: Peer::new(id.0 as u32),
 310            app_state,
 311            connection_pool: Default::default(),
 312            handlers: Default::default(),
 313            teardown: watch::channel(false).0,
 314        };
 315
 316        server
 317            .add_request_handler(ping)
 318            .add_request_handler(create_room)
 319            .add_request_handler(join_room)
 320            .add_request_handler(rejoin_room)
 321            .add_request_handler(leave_room)
 322            .add_request_handler(set_room_participant_role)
 323            .add_request_handler(call)
 324            .add_request_handler(cancel_call)
 325            .add_message_handler(decline_call)
 326            .add_request_handler(update_participant_location)
 327            .add_request_handler(share_project)
 328            .add_message_handler(unshare_project)
 329            .add_request_handler(join_project)
 330            .add_message_handler(leave_project)
 331            .add_request_handler(update_project)
 332            .add_request_handler(update_worktree)
 333            .add_request_handler(update_repository)
 334            .add_request_handler(remove_repository)
 335            .add_message_handler(start_language_server)
 336            .add_message_handler(update_language_server)
 337            .add_message_handler(update_diagnostic_summary)
 338            .add_message_handler(update_worktree_settings)
 339            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 340            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 341            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 342            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 343            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 344            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 345            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 346            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 347            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 348            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 349            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
 350            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 351            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
 352            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 353            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 354            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 355            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 356            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 357            .add_request_handler(
 358                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 359            )
 360            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 361            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 362            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 363            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 364            .add_request_handler(
 365                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 366            )
 367            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 368            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 369            .add_request_handler(
 370                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 371            )
 372            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 373            .add_request_handler(
 374                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 375            )
 376            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 377            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 378            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 379            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 380            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 381            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 382            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 383            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 384            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 385            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 386            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 387            .add_request_handler(
 388                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 389            )
 390            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 391            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 392            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 393            .add_request_handler(lsp_query)
 394            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
 395            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 396            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 397            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 398            .add_message_handler(create_buffer_for_peer)
 399            .add_message_handler(create_image_for_peer)
 400            .add_request_handler(update_buffer)
 401            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 402            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 403            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 404            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 405            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 406            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 407            .add_message_handler(
 408                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 409            )
 410            .add_request_handler(get_users)
 411            .add_request_handler(fuzzy_search_users)
 412            .add_request_handler(request_contact)
 413            .add_request_handler(remove_contact)
 414            .add_request_handler(respond_to_contact_request)
 415            .add_message_handler(subscribe_to_channels)
 416            .add_request_handler(create_channel)
 417            .add_request_handler(delete_channel)
 418            .add_request_handler(invite_channel_member)
 419            .add_request_handler(remove_channel_member)
 420            .add_request_handler(set_channel_member_role)
 421            .add_request_handler(set_channel_visibility)
 422            .add_request_handler(rename_channel)
 423            .add_request_handler(join_channel_buffer)
 424            .add_request_handler(leave_channel_buffer)
 425            .add_message_handler(update_channel_buffer)
 426            .add_request_handler(rejoin_channel_buffers)
 427            .add_request_handler(get_channel_members)
 428            .add_request_handler(respond_to_channel_invite)
 429            .add_request_handler(join_channel)
 430            .add_request_handler(join_channel_chat)
 431            .add_message_handler(leave_channel_chat)
 432            .add_request_handler(send_channel_message)
 433            .add_request_handler(remove_channel_message)
 434            .add_request_handler(update_channel_message)
 435            .add_request_handler(get_channel_messages)
 436            .add_request_handler(get_channel_messages_by_id)
 437            .add_request_handler(get_notifications)
 438            .add_request_handler(mark_notification_as_read)
 439            .add_request_handler(move_channel)
 440            .add_request_handler(reorder_channel)
 441            .add_request_handler(follow)
 442            .add_message_handler(unfollow)
 443            .add_message_handler(update_followers)
 444            .add_message_handler(acknowledge_channel_message)
 445            .add_message_handler(acknowledge_buffer_version)
 446            .add_request_handler(get_supermaven_api_key)
 447            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 448            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 449            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 450            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 451            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 452            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 453            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 454            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
 455            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 456            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
 457            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 458            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 459            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 460            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 461            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 462            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 463            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 464            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 465            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 466            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 467            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 468            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
 469            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
 470            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 471            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 472            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
 473            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
 474            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 475            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 476            .add_message_handler(update_context)
 477            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
 478            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>);
 479
 480        Arc::new(server)
 481    }
 482
 483    pub async fn start(&self) -> Result<()> {
 484        let server_id = *self.id.lock();
 485        let app_state = self.app_state.clone();
 486        let peer = self.peer.clone();
 487        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 488        let pool = self.connection_pool.clone();
 489        let livekit_client = self.app_state.livekit_client.clone();
 490
 491        let span = info_span!("start server");
 492        self.app_state.executor.spawn_detached(
 493            async move {
 494                tracing::info!("waiting for cleanup timeout");
 495                timeout.await;
 496                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 497
 498                app_state
 499                    .db
 500                    .delete_stale_channel_chat_participants(
 501                        &app_state.config.zed_environment,
 502                        server_id,
 503                    )
 504                    .await
 505                    .trace_err();
 506
 507                if let Some((room_ids, channel_ids)) = app_state
 508                    .db
 509                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 510                    .await
 511                    .trace_err()
 512                {
 513                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 514                    tracing::info!(
 515                        stale_channel_buffer_count = channel_ids.len(),
 516                        "retrieved stale channel buffers"
 517                    );
 518
 519                    for channel_id in channel_ids {
 520                        if let Some(refreshed_channel_buffer) = app_state
 521                            .db
 522                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 523                            .await
 524                            .trace_err()
 525                        {
 526                            for connection_id in refreshed_channel_buffer.connection_ids {
 527                                peer.send(
 528                                    connection_id,
 529                                    proto::UpdateChannelBufferCollaborators {
 530                                        channel_id: channel_id.to_proto(),
 531                                        collaborators: refreshed_channel_buffer
 532                                            .collaborators
 533                                            .clone(),
 534                                    },
 535                                )
 536                                .trace_err();
 537                            }
 538                        }
 539                    }
 540
 541                    for room_id in room_ids {
 542                        let mut contacts_to_update = HashSet::default();
 543                        let mut canceled_calls_to_user_ids = Vec::new();
 544                        let mut livekit_room = String::new();
 545                        let mut delete_livekit_room = false;
 546
 547                        if let Some(mut refreshed_room) = app_state
 548                            .db
 549                            .clear_stale_room_participants(room_id, server_id)
 550                            .await
 551                            .trace_err()
 552                        {
 553                            tracing::info!(
 554                                room_id = room_id.0,
 555                                new_participant_count = refreshed_room.room.participants.len(),
 556                                "refreshed room"
 557                            );
 558                            room_updated(&refreshed_room.room, &peer);
 559                            if let Some(channel) = refreshed_room.channel.as_ref() {
 560                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 561                            }
 562                            contacts_to_update
 563                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 564                            contacts_to_update
 565                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 566                            canceled_calls_to_user_ids =
 567                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 568                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
 569                            delete_livekit_room = refreshed_room.room.participants.is_empty();
 570                        }
 571
 572                        {
 573                            let pool = pool.lock();
 574                            for canceled_user_id in canceled_calls_to_user_ids {
 575                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 576                                    peer.send(
 577                                        connection_id,
 578                                        proto::CallCanceled {
 579                                            room_id: room_id.to_proto(),
 580                                        },
 581                                    )
 582                                    .trace_err();
 583                                }
 584                            }
 585                        }
 586
 587                        for user_id in contacts_to_update {
 588                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 589                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 590                            if let Some((busy, contacts)) = busy.zip(contacts) {
 591                                let pool = pool.lock();
 592                                let updated_contact = contact_for_user(user_id, busy, &pool);
 593                                for contact in contacts {
 594                                    if let db::Contact::Accepted {
 595                                        user_id: contact_user_id,
 596                                        ..
 597                                    } = contact
 598                                    {
 599                                        for contact_conn_id in
 600                                            pool.user_connection_ids(contact_user_id)
 601                                        {
 602                                            peer.send(
 603                                                contact_conn_id,
 604                                                proto::UpdateContacts {
 605                                                    contacts: vec![updated_contact.clone()],
 606                                                    remove_contacts: Default::default(),
 607                                                    incoming_requests: Default::default(),
 608                                                    remove_incoming_requests: Default::default(),
 609                                                    outgoing_requests: Default::default(),
 610                                                    remove_outgoing_requests: Default::default(),
 611                                                },
 612                                            )
 613                                            .trace_err();
 614                                        }
 615                                    }
 616                                }
 617                            }
 618                        }
 619
 620                        if let Some(live_kit) = livekit_client.as_ref()
 621                            && delete_livekit_room
 622                        {
 623                            live_kit.delete_room(livekit_room).await.trace_err();
 624                        }
 625                    }
 626                }
 627
 628                app_state
 629                    .db
 630                    .delete_stale_channel_chat_participants(
 631                        &app_state.config.zed_environment,
 632                        server_id,
 633                    )
 634                    .await
 635                    .trace_err();
 636
 637                app_state
 638                    .db
 639                    .clear_old_worktree_entries(server_id)
 640                    .await
 641                    .trace_err();
 642
 643                app_state
 644                    .db
 645                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 646                    .await
 647                    .trace_err();
 648            }
 649            .instrument(span),
 650        );
 651        Ok(())
 652    }
 653
 654    pub fn teardown(&self) {
 655        self.peer.teardown();
 656        self.connection_pool.lock().reset();
 657        let _ = self.teardown.send(true);
 658    }
 659
 660    #[cfg(test)]
 661    pub fn reset(&self, id: ServerId) {
 662        self.teardown();
 663        *self.id.lock() = id;
 664        self.peer.reset(id.0 as u32);
 665        let _ = self.teardown.send(false);
 666    }
 667
    /// Test-only: returns the server's current `ServerId`.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
 672
    /// Registers `handler` as the handler for messages of type `M`.
    ///
    /// Handlers are stored keyed by `TypeId::of::<M>()`. The stored closure
    /// downcasts the type-erased envelope, invokes `handler` immediately, and
    /// returns a boxed future. That future awaits the handler's future, records
    /// total / processing / queue durations on the message span, and logs
    /// whether handling succeeded.
    ///
    /// # Panics
    ///
    /// Panics if a handler for `M` has already been registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // The map is keyed by `TypeId::of::<M>()`, so the envelope is expected to be
                // a `TypedEnvelope<M>`; a failed downcast would indicate a registration bug.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                // The handler is called eagerly here; only the future it returns is deferred.
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // total: since the message was received; processing: since the handler
                    // was invoked above; queue: the difference between the two.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
 716
 717    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 718    where
 719        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 720        Fut: 'static + Send + Future<Output = Result<()>>,
 721        M: EnvelopedMessage,
 722    {
 723        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 724        self
 725    }
 726
 727    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 728    where
 729        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 730        Fut: Send + Future<Output = Result<()>>,
 731        M: RequestMessage,
 732    {
 733        let handler = Arc::new(handler);
 734        self.add_handler(move |envelope, session| {
 735            let receipt = envelope.receipt();
 736            let handler = handler.clone();
 737            async move {
 738                let peer = session.peer.clone();
 739                let responded = Arc::new(AtomicBool::default());
 740                let response = Response {
 741                    peer: peer.clone(),
 742                    responded: responded.clone(),
 743                    receipt,
 744                };
 745                match (handler)(envelope.payload, response, session).await {
 746                    Ok(()) => {
 747                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 748                            Ok(())
 749                        } else {
 750                            Err(anyhow!("handler did not send a response"))?
 751                        }
 752                    }
 753                    Err(error) => {
 754                        let proto_err = match &error {
 755                            Error::Internal(err) => err.to_proto(),
 756                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 757                        };
 758                        peer.respond_with_error(receipt, proto_err)?;
 759                        Err(error)
 760                    }
 761                }
 762            }
 763        })
 764    }
 765
    /// Returns the future that services a single client connection.
    ///
    /// Span metadata (address, principal, user agent, release channel, GeoIP
    /// country) is attached before the future is created.
    ///
    /// When polled, the future:
    /// 1. registers the connection with the peer;
    /// 2. builds the per-connection [`Session`];
    /// 3. sends the initial client update;
    /// 4. dispatches incoming messages to the registered handlers.
    ///
    /// Dispatch continues until the I/O future finishes, the incoming stream
    /// ends, or the teardown channel changes. In the first two cases it then
    /// runs `connection_lost`. On teardown, or when setup fails, it returns
    /// without doing so.
    ///
    /// `send_connection_id`, when provided, receives the assigned
    /// `ConnectionId` right after the `Hello` message is sent.
    /// `connection_guard` is held until the initial update has been sent (or
    /// setup fails). The future owns a clone of the server `Arc` and captures
    /// no borrowed lifetimes (`use<>`).
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribe eagerly, outside the async block, so the receiver exists before
        // the returned future is first polled.
        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse the connection if teardown has already been signalled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only built when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The initial handshake is complete; release the guard that was acquired
            // before the WebSocket upgrade in `handle_websocket_request`.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            // Number of handlers currently holding a permit.
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // A permit is acquired *before* reading the next message. When all
                // MAX_CONCURRENT_HANDLERS permits are held, no further messages are read.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in order: teardown (exit without running
                // `connection_lost`), I/O completion, progress on foreground handlers,
                // then the next incoming message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            // The handler is invoked synchronously inside the message span;
                            // the span is exited before the returned future is wrapped.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // The permit is released only after the handler has finished.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    // Background messages are spawned on the executor rather
                                    // than driven by this loop.
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any unfinished foreground handlers and releases
            // their permits; the count below therefore reflects handlers that still
            // hold a permit (spawned background handlers).
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
 957
 958    async fn send_initial_client_update(
 959        &self,
 960        connection_id: ConnectionId,
 961        zed_version: ZedVersion,
 962        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 963        session: &Session,
 964    ) -> Result<()> {
 965        self.peer.send(
 966            connection_id,
 967            proto::Hello {
 968                peer_id: Some(connection_id.into()),
 969            },
 970        )?;
 971        tracing::info!("sent hello message");
 972        if let Some(send_connection_id) = send_connection_id.take() {
 973            let _ = send_connection_id.send(connection_id);
 974        }
 975
 976        match &session.principal {
 977            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 978                if !user.connected_once {
 979                    self.peer.send(connection_id, proto::ShowContacts {})?;
 980                    self.app_state
 981                        .db
 982                        .set_user_connected_once(user.id, true)
 983                        .await?;
 984                }
 985
 986                let contacts = self.app_state.db.get_contacts(user.id).await?;
 987
 988                {
 989                    let mut pool = self.connection_pool.lock();
 990                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
 991                    self.peer.send(
 992                        connection_id,
 993                        build_initial_contacts_update(contacts, &pool),
 994                    )?;
 995                }
 996
 997                if should_auto_subscribe_to_channels(&zed_version) {
 998                    subscribe_user_to_channels(user.id, session).await?;
 999                }
1000
1001                if let Some(incoming_call) =
1002                    self.app_state.db.incoming_call_for_user(user.id).await?
1003                {
1004                    self.peer.send(connection_id, incoming_call)?;
1005                }
1006
1007                update_user_contacts(user.id, session).await?;
1008            }
1009        }
1010
1011        Ok(())
1012    }
1013
1014    pub async fn invite_code_redeemed(
1015        self: &Arc<Self>,
1016        inviter_id: UserId,
1017        invitee_id: UserId,
1018    ) -> Result<()> {
1019        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await?
1020            && let Some(code) = &user.invite_code
1021        {
1022            let pool = self.connection_pool.lock();
1023            let invitee_contact = contact_for_user(invitee_id, false, &pool);
1024            for connection_id in pool.user_connection_ids(inviter_id) {
1025                self.peer.send(
1026                    connection_id,
1027                    proto::UpdateContacts {
1028                        contacts: vec![invitee_contact.clone()],
1029                        ..Default::default()
1030                    },
1031                )?;
1032                self.peer.send(
1033                    connection_id,
1034                    proto::UpdateInviteInfo {
1035                        url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1036                        count: user.invite_count as u32,
1037                    },
1038                )?;
1039            }
1040        }
1041        Ok(())
1042    }
1043
1044    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1045        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await?
1046            && let Some(invite_code) = &user.invite_code
1047        {
1048            let pool = self.connection_pool.lock();
1049            for connection_id in pool.user_connection_ids(user_id) {
1050                self.peer.send(
1051                    connection_id,
1052                    proto::UpdateInviteInfo {
1053                        url: format!(
1054                            "{}{}",
1055                            self.app_state.config.invite_link_prefix, invite_code
1056                        ),
1057                        count: user.invite_count as u32,
1058                    },
1059                )?;
1060            }
1061        }
1062        Ok(())
1063    }
1064
1065    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1066        ServerSnapshot {
1067            connection_pool: ConnectionPoolGuard {
1068                guard: self.connection_pool.lock(),
1069                _not_send: PhantomData,
1070            },
1071            peer: &self.peer,
1072        }
1073    }
1074}
1075
1076impl Deref for ConnectionPoolGuard<'_> {
1077    type Target = ConnectionPool;
1078
1079    fn deref(&self) -> &Self::Target {
1080        &self.guard
1081    }
1082}
1083
1084impl DerefMut for ConnectionPoolGuard<'_> {
1085    fn deref_mut(&mut self) -> &mut Self::Target {
1086        &mut self.guard
1087    }
1088}
1089
/// In test builds, checks the pool's invariants each time a guard is
/// dropped; in non-test builds this `drop` body is empty.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1096
1097fn broadcast<F>(
1098    sender_id: Option<ConnectionId>,
1099    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1100    mut f: F,
1101) where
1102    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1103{
1104    for receiver_id in receiver_ids {
1105        if Some(receiver_id) != sender_id
1106            && let Err(error) = f(receiver_id)
1107        {
1108            tracing::error!("failed to send to {:?} {}", receiver_id, error);
1109        }
1110    }
1111}
1112
1113pub struct ProtocolVersion(u32);
1114
1115impl Header for ProtocolVersion {
1116    fn name() -> &'static HeaderName {
1117        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1118        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1119    }
1120
1121    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1122    where
1123        Self: Sized,
1124        I: Iterator<Item = &'i axum::http::HeaderValue>,
1125    {
1126        let version = values
1127            .next()
1128            .ok_or_else(axum::headers::Error::invalid)?
1129            .to_str()
1130            .map_err(|_| axum::headers::Error::invalid())?
1131            .parse()
1132            .map_err(|_| axum::headers::Error::invalid())?;
1133        Ok(Self(version))
1134    }
1135
1136    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1137        values.extend([self.0.to_string().parse().unwrap()]);
1138    }
1139}
1140
1141pub struct AppVersionHeader(Version);
1142impl Header for AppVersionHeader {
1143    fn name() -> &'static HeaderName {
1144        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1145        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1146    }
1147
1148    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1149    where
1150        Self: Sized,
1151        I: Iterator<Item = &'i axum::http::HeaderValue>,
1152    {
1153        let version = values
1154            .next()
1155            .ok_or_else(axum::headers::Error::invalid)?
1156            .to_str()
1157            .map_err(|_| axum::headers::Error::invalid())?
1158            .parse()
1159            .map_err(|_| axum::headers::Error::invalid())?;
1160        Ok(Self(version))
1161    }
1162
1163    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1164        values.extend([self.0.to_string().parse().unwrap()]);
1165    }
1166}
1167
1168#[derive(Debug)]
1169pub struct ReleaseChannelHeader(String);
1170
1171impl Header for ReleaseChannelHeader {
1172    fn name() -> &'static HeaderName {
1173        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1174        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1175    }
1176
1177    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1178    where
1179        Self: Sized,
1180        I: Iterator<Item = &'i axum::http::HeaderValue>,
1181    {
1182        Ok(Self(
1183            values
1184                .next()
1185                .ok_or_else(axum::headers::Error::invalid)?
1186                .to_str()
1187                .map_err(|_| axum::headers::Error::invalid())?
1188                .to_owned(),
1189        ))
1190    }
1191
1192    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1193        values.extend([self.0.parse().unwrap()]);
1194    }
1195}
1196
/// Builds the router serving `/rpc` (WebSocket upgrade) and `/metrics`.
///
/// Ordering matters: axum's `Router::layer` wraps only the routes registered
/// before the call. `/rpc` is therefore wrapped by the `AppState` extension
/// and the `auth::validate_header` middleware, while `/metrics`, added
/// afterwards, is not. The final `Extension(server)` layer applies to both.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        // Registered after the auth layer above, so it is not covered by it.
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1208
1209pub async fn handle_websocket_request(
1210    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1211    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1212    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1213    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1214    Extension(server): Extension<Arc<Server>>,
1215    Extension(principal): Extension<Principal>,
1216    user_agent: Option<TypedHeader<UserAgent>>,
1217    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1218    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1219    ws: WebSocketUpgrade,
1220) -> axum::response::Response {
1221    if protocol_version != rpc::PROTOCOL_VERSION {
1222        return (
1223            StatusCode::UPGRADE_REQUIRED,
1224            "client must be upgraded".to_string(),
1225        )
1226            .into_response();
1227    }
1228
1229    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1230        return (
1231            StatusCode::UPGRADE_REQUIRED,
1232            "no version header found".to_string(),
1233        )
1234            .into_response();
1235    };
1236
1237    let release_channel = release_channel_header.map(|header| header.0.0);
1238
1239    if !version.can_collaborate() {
1240        return (
1241            StatusCode::UPGRADE_REQUIRED,
1242            "client must be upgraded".to_string(),
1243        )
1244            .into_response();
1245    }
1246
1247    let socket_address = socket_address.to_string();
1248
1249    // Acquire connection guard before WebSocket upgrade
1250    let connection_guard = match ConnectionGuard::try_acquire() {
1251        Ok(guard) => guard,
1252        Err(()) => {
1253            return (
1254                StatusCode::SERVICE_UNAVAILABLE,
1255                "Too many concurrent connections",
1256            )
1257                .into_response();
1258        }
1259    };
1260
1261    ws.on_upgrade(move |socket| {
1262        let socket = socket
1263            .map_ok(to_tungstenite_message)
1264            .err_into()
1265            .with(|message| async move { to_axum_message(message) });
1266        let connection = Connection::new(Box::pin(socket));
1267        async move {
1268            server
1269                .handle_connection(
1270                    connection,
1271                    socket_address,
1272                    principal,
1273                    version,
1274                    release_channel,
1275                    user_agent.map(|header| header.to_string()),
1276                    country_code_header.map(|header| header.to_string()),
1277                    system_id_header.map(|header| header.to_string()),
1278                    None,
1279                    Executor::Production,
1280                    Some(connection_guard),
1281                )
1282                .await;
1283        }
1284    })
1285}
1286
1287pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1288    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1289    let connections_metric = CONNECTIONS_METRIC
1290        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1291
1292    let connections = server
1293        .connection_pool
1294        .lock()
1295        .connections()
1296        .filter(|connection| !connection.admin)
1297        .count();
1298    connections_metric.set(connections as _);
1299
1300    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1301    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1302        register_int_gauge!(
1303            "shared_projects",
1304            "number of open projects with one or more guests"
1305        )
1306        .unwrap()
1307    });
1308
1309    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1310    shared_projects_metric.set(shared_projects as _);
1311
1312    let encoder = prometheus::TextEncoder::new();
1313    let metric_families = prometheus::gather();
1314    let encoded_metrics = encoder
1315        .encode_to_string(&metric_families)
1316        .map_err(|err| anyhow!("{err}"))?;
1317    Ok(encoded_metrics)
1318}
1319
1320#[instrument(err, skip(executor))]
1321async fn connection_lost(
1322    session: Session,
1323    mut teardown: watch::Receiver<bool>,
1324    executor: Executor,
1325) -> Result<()> {
1326    session.peer.disconnect(session.connection_id);
1327    session
1328        .connection_pool()
1329        .await
1330        .remove_connection(session.connection_id)?;
1331
1332    session
1333        .db()
1334        .await
1335        .connection_lost(session.connection_id)
1336        .await
1337        .trace_err();
1338
1339    futures::select_biased! {
1340        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1341
1342            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1343            leave_room_for_session(&session, session.connection_id).await.trace_err();
1344            leave_channel_buffers_for_session(&session)
1345                .await
1346                .trace_err();
1347
1348            if !session
1349                .connection_pool()
1350                .await
1351                .is_user_online(session.user_id())
1352            {
1353                let db = session.db().await;
1354                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1355                    room_updated(&room, &session.peer);
1356                }
1357            }
1358
1359            update_user_contacts(session.user_id(), &session).await?;
1360        },
1361        _ = teardown.changed().fuse() => {}
1362    }
1363
1364    Ok(())
1365}
1366
1367/// Acknowledges a ping from a client, used to keep the connection alive.
1368async fn ping(
1369    _: proto::Ping,
1370    response: Response<proto::Ping>,
1371    _session: MessageContext,
1372) -> Result<()> {
1373    response.send(proto::Ack {})?;
1374    Ok(())
1375}
1376
1377/// Creates a new room for calling (outside of channels)
1378async fn create_room(
1379    _request: proto::CreateRoom,
1380    response: Response<proto::CreateRoom>,
1381    session: MessageContext,
1382) -> Result<()> {
1383    let livekit_room = nanoid::nanoid!(30);
1384
1385    let live_kit_connection_info = util::maybe!(async {
1386        let live_kit = session.app_state.livekit_client.as_ref();
1387        let live_kit = live_kit?;
1388        let user_id = session.user_id().to_string();
1389
1390        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1391
1392        Some(proto::LiveKitConnectionInfo {
1393            server_url: live_kit.url().into(),
1394            token,
1395            can_publish: true,
1396        })
1397    })
1398    .await;
1399
1400    let room = session
1401        .db()
1402        .await
1403        .create_room(session.user_id(), session.connection_id, &livekit_room)
1404        .await?;
1405
1406    response.send(proto::CreateRoomResponse {
1407        room: Some(room.clone()),
1408        live_kit_connection_info,
1409    })?;
1410
1411    update_user_contacts(session.user_id(), &session).await?;
1412    Ok(())
1413}
1414
1415/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1416async fn join_room(
1417    request: proto::JoinRoom,
1418    response: Response<proto::JoinRoom>,
1419    session: MessageContext,
1420) -> Result<()> {
1421    let room_id = RoomId::from_proto(request.id);
1422
1423    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1424
1425    if let Some(channel_id) = channel_id {
1426        return join_channel_internal(channel_id, Box::new(response), session).await;
1427    }
1428
1429    let joined_room = {
1430        let room = session
1431            .db()
1432            .await
1433            .join_room(room_id, session.user_id(), session.connection_id)
1434            .await?;
1435        room_updated(&room.room, &session.peer);
1436        room.into_inner()
1437    };
1438
1439    for connection_id in session
1440        .connection_pool()
1441        .await
1442        .user_connection_ids(session.user_id())
1443    {
1444        session
1445            .peer
1446            .send(
1447                connection_id,
1448                proto::CallCanceled {
1449                    room_id: room_id.to_proto(),
1450                },
1451            )
1452            .trace_err();
1453    }
1454
1455    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1456    {
1457        live_kit
1458            .room_token(
1459                &joined_room.room.livekit_room,
1460                &session.user_id().to_string(),
1461            )
1462            .trace_err()
1463            .map(|token| proto::LiveKitConnectionInfo {
1464                server_url: live_kit.url().into(),
1465                token,
1466                can_publish: true,
1467            })
1468    } else {
1469        None
1470    };
1471
1472    response.send(proto::JoinRoomResponse {
1473        room: Some(joined_room.room),
1474        channel_id: None,
1475        live_kit_connection_info,
1476    })?;
1477
1478    update_user_contacts(session.user_id(), &session).await?;
1479    Ok(())
1480}
1481
1482/// Rejoin room is used to reconnect to a room after connection errors.
1483async fn rejoin_room(
1484    request: proto::RejoinRoom,
1485    response: Response<proto::RejoinRoom>,
1486    session: MessageContext,
1487) -> Result<()> {
1488    let room;
1489    let channel;
1490    {
1491        let mut rejoined_room = session
1492            .db()
1493            .await
1494            .rejoin_room(request, session.user_id(), session.connection_id)
1495            .await?;
1496
1497        response.send(proto::RejoinRoomResponse {
1498            room: Some(rejoined_room.room.clone()),
1499            reshared_projects: rejoined_room
1500                .reshared_projects
1501                .iter()
1502                .map(|project| proto::ResharedProject {
1503                    id: project.id.to_proto(),
1504                    collaborators: project
1505                        .collaborators
1506                        .iter()
1507                        .map(|collaborator| collaborator.to_proto())
1508                        .collect(),
1509                })
1510                .collect(),
1511            rejoined_projects: rejoined_room
1512                .rejoined_projects
1513                .iter()
1514                .map(|rejoined_project| rejoined_project.to_proto())
1515                .collect(),
1516        })?;
1517        room_updated(&rejoined_room.room, &session.peer);
1518
1519        for project in &rejoined_room.reshared_projects {
1520            for collaborator in &project.collaborators {
1521                session
1522                    .peer
1523                    .send(
1524                        collaborator.connection_id,
1525                        proto::UpdateProjectCollaborator {
1526                            project_id: project.id.to_proto(),
1527                            old_peer_id: Some(project.old_connection_id.into()),
1528                            new_peer_id: Some(session.connection_id.into()),
1529                        },
1530                    )
1531                    .trace_err();
1532            }
1533
1534            broadcast(
1535                Some(session.connection_id),
1536                project
1537                    .collaborators
1538                    .iter()
1539                    .map(|collaborator| collaborator.connection_id),
1540                |connection_id| {
1541                    session.peer.forward_send(
1542                        session.connection_id,
1543                        connection_id,
1544                        proto::UpdateProject {
1545                            project_id: project.id.to_proto(),
1546                            worktrees: project.worktrees.clone(),
1547                        },
1548                    )
1549                },
1550            );
1551        }
1552
1553        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1554
1555        let rejoined_room = rejoined_room.into_inner();
1556
1557        room = rejoined_room.room;
1558        channel = rejoined_room.channel;
1559    }
1560
1561    if let Some(channel) = channel {
1562        channel_updated(
1563            &channel,
1564            &room,
1565            &session.peer,
1566            &*session.connection_pool().await,
1567        );
1568    }
1569
1570    update_user_contacts(session.user_id(), &session).await?;
1571    Ok(())
1572}
1573
1574fn notify_rejoined_projects(
1575    rejoined_projects: &mut Vec<RejoinedProject>,
1576    session: &Session,
1577) -> Result<()> {
1578    for project in rejoined_projects.iter() {
1579        for collaborator in &project.collaborators {
1580            session
1581                .peer
1582                .send(
1583                    collaborator.connection_id,
1584                    proto::UpdateProjectCollaborator {
1585                        project_id: project.id.to_proto(),
1586                        old_peer_id: Some(project.old_connection_id.into()),
1587                        new_peer_id: Some(session.connection_id.into()),
1588                    },
1589                )
1590                .trace_err();
1591        }
1592    }
1593
1594    for project in rejoined_projects {
1595        for worktree in mem::take(&mut project.worktrees) {
1596            // Stream this worktree's entries.
1597            let message = proto::UpdateWorktree {
1598                project_id: project.id.to_proto(),
1599                worktree_id: worktree.id,
1600                abs_path: worktree.abs_path.clone(),
1601                root_name: worktree.root_name,
1602                updated_entries: worktree.updated_entries,
1603                removed_entries: worktree.removed_entries,
1604                scan_id: worktree.scan_id,
1605                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1606                updated_repositories: worktree.updated_repositories,
1607                removed_repositories: worktree.removed_repositories,
1608            };
1609            for update in proto::split_worktree_update(message) {
1610                session.peer.send(session.connection_id, update)?;
1611            }
1612
1613            // Stream this worktree's diagnostics.
1614            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1615            if let Some(summary) = worktree_diagnostics.next() {
1616                let message = proto::UpdateDiagnosticSummary {
1617                    project_id: project.id.to_proto(),
1618                    worktree_id: worktree.id,
1619                    summary: Some(summary),
1620                    more_summaries: worktree_diagnostics.collect(),
1621                };
1622                session.peer.send(session.connection_id, message)?;
1623            }
1624
1625            for settings_file in worktree.settings_files {
1626                session.peer.send(
1627                    session.connection_id,
1628                    proto::UpdateWorktreeSettings {
1629                        project_id: project.id.to_proto(),
1630                        worktree_id: worktree.id,
1631                        path: settings_file.path,
1632                        content: Some(settings_file.content),
1633                        kind: Some(settings_file.kind.to_proto().into()),
1634                    },
1635                )?;
1636            }
1637        }
1638
1639        for repository in mem::take(&mut project.updated_repositories) {
1640            for update in split_repository_update(repository) {
1641                session.peer.send(session.connection_id, update)?;
1642            }
1643        }
1644
1645        for id in mem::take(&mut project.removed_repositories) {
1646            session.peer.send(
1647                session.connection_id,
1648                proto::RemoveRepository {
1649                    project_id: project.id.to_proto(),
1650                    id,
1651                },
1652            )?;
1653        }
1654    }
1655
1656    Ok(())
1657}
1658
1659/// leave room disconnects from the room.
1660async fn leave_room(
1661    _: proto::LeaveRoom,
1662    response: Response<proto::LeaveRoom>,
1663    session: MessageContext,
1664) -> Result<()> {
1665    leave_room_for_session(&session, session.connection_id).await?;
1666    response.send(proto::Ack {})?;
1667    Ok(())
1668}
1669
1670/// Updates the permissions of someone else in the room.
1671async fn set_room_participant_role(
1672    request: proto::SetRoomParticipantRole,
1673    response: Response<proto::SetRoomParticipantRole>,
1674    session: MessageContext,
1675) -> Result<()> {
1676    let user_id = UserId::from_proto(request.user_id);
1677    let role = ChannelRole::from(request.role());
1678
1679    let (livekit_room, can_publish) = {
1680        let room = session
1681            .db()
1682            .await
1683            .set_room_participant_role(
1684                session.user_id(),
1685                RoomId::from_proto(request.room_id),
1686                user_id,
1687                role,
1688            )
1689            .await?;
1690
1691        let livekit_room = room.livekit_room.clone();
1692        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1693        room_updated(&room, &session.peer);
1694        (livekit_room, can_publish)
1695    };
1696
1697    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1698        live_kit
1699            .update_participant(
1700                livekit_room.clone(),
1701                request.user_id.to_string(),
1702                livekit_api::proto::ParticipantPermission {
1703                    can_subscribe: true,
1704                    can_publish,
1705                    can_publish_data: can_publish,
1706                    hidden: false,
1707                    recorder: false,
1708                },
1709            )
1710            .await
1711            .trace_err();
1712    }
1713
1714    response.send(proto::Ack {})?;
1715    Ok(())
1716}
1717
1718/// Call someone else into the current room
1719async fn call(
1720    request: proto::Call,
1721    response: Response<proto::Call>,
1722    session: MessageContext,
1723) -> Result<()> {
1724    let room_id = RoomId::from_proto(request.room_id);
1725    let calling_user_id = session.user_id();
1726    let calling_connection_id = session.connection_id;
1727    let called_user_id = UserId::from_proto(request.called_user_id);
1728    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1729    if !session
1730        .db()
1731        .await
1732        .has_contact(calling_user_id, called_user_id)
1733        .await?
1734    {
1735        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1736    }
1737
1738    let incoming_call = {
1739        let (room, incoming_call) = &mut *session
1740            .db()
1741            .await
1742            .call(
1743                room_id,
1744                calling_user_id,
1745                calling_connection_id,
1746                called_user_id,
1747                initial_project_id,
1748            )
1749            .await?;
1750        room_updated(room, &session.peer);
1751        mem::take(incoming_call)
1752    };
1753    update_user_contacts(called_user_id, &session).await?;
1754
1755    let mut calls = session
1756        .connection_pool()
1757        .await
1758        .user_connection_ids(called_user_id)
1759        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1760        .collect::<FuturesUnordered<_>>();
1761
1762    while let Some(call_response) = calls.next().await {
1763        match call_response.as_ref() {
1764            Ok(_) => {
1765                response.send(proto::Ack {})?;
1766                return Ok(());
1767            }
1768            Err(_) => {
1769                call_response.trace_err();
1770            }
1771        }
1772    }
1773
1774    {
1775        let room = session
1776            .db()
1777            .await
1778            .call_failed(room_id, called_user_id)
1779            .await?;
1780        room_updated(&room, &session.peer);
1781    }
1782    update_user_contacts(called_user_id, &session).await?;
1783
1784    Err(anyhow!("failed to ring user"))?
1785}
1786
1787/// Cancel an outgoing call.
1788async fn cancel_call(
1789    request: proto::CancelCall,
1790    response: Response<proto::CancelCall>,
1791    session: MessageContext,
1792) -> Result<()> {
1793    let called_user_id = UserId::from_proto(request.called_user_id);
1794    let room_id = RoomId::from_proto(request.room_id);
1795    {
1796        let room = session
1797            .db()
1798            .await
1799            .cancel_call(room_id, session.connection_id, called_user_id)
1800            .await?;
1801        room_updated(&room, &session.peer);
1802    }
1803
1804    for connection_id in session
1805        .connection_pool()
1806        .await
1807        .user_connection_ids(called_user_id)
1808    {
1809        session
1810            .peer
1811            .send(
1812                connection_id,
1813                proto::CallCanceled {
1814                    room_id: room_id.to_proto(),
1815                },
1816            )
1817            .trace_err();
1818    }
1819    response.send(proto::Ack {})?;
1820
1821    update_user_contacts(called_user_id, &session).await?;
1822    Ok(())
1823}
1824
1825/// Decline an incoming call.
1826async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1827    let room_id = RoomId::from_proto(message.room_id);
1828    {
1829        let room = session
1830            .db()
1831            .await
1832            .decline_call(Some(room_id), session.user_id())
1833            .await?
1834            .context("declining call")?;
1835        room_updated(&room, &session.peer);
1836    }
1837
1838    for connection_id in session
1839        .connection_pool()
1840        .await
1841        .user_connection_ids(session.user_id())
1842    {
1843        session
1844            .peer
1845            .send(
1846                connection_id,
1847                proto::CallCanceled {
1848                    room_id: room_id.to_proto(),
1849                },
1850            )
1851            .trace_err();
1852    }
1853    update_user_contacts(session.user_id(), &session).await?;
1854    Ok(())
1855}
1856
1857/// Updates other participants in the room with your current location.
1858async fn update_participant_location(
1859    request: proto::UpdateParticipantLocation,
1860    response: Response<proto::UpdateParticipantLocation>,
1861    session: MessageContext,
1862) -> Result<()> {
1863    let room_id = RoomId::from_proto(request.room_id);
1864    let location = request.location.context("invalid location")?;
1865
1866    let db = session.db().await;
1867    let room = db
1868        .update_room_participant_location(room_id, session.connection_id, location)
1869        .await?;
1870
1871    room_updated(&room, &session.peer);
1872    response.send(proto::Ack {})?;
1873    Ok(())
1874}
1875
1876/// Share a project into the room.
1877async fn share_project(
1878    request: proto::ShareProject,
1879    response: Response<proto::ShareProject>,
1880    session: MessageContext,
1881) -> Result<()> {
1882    let (project_id, room) = &*session
1883        .db()
1884        .await
1885        .share_project(
1886            RoomId::from_proto(request.room_id),
1887            session.connection_id,
1888            &request.worktrees,
1889            request.is_ssh_project,
1890            request.windows_paths.unwrap_or(false),
1891        )
1892        .await?;
1893    response.send(proto::ShareProjectResponse {
1894        project_id: project_id.to_proto(),
1895    })?;
1896    room_updated(room, &session.peer);
1897
1898    Ok(())
1899}
1900
1901/// Unshare a project from the room.
1902async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1903    let project_id = ProjectId::from_proto(message.project_id);
1904    unshare_project_internal(project_id, session.connection_id, &session).await
1905}
1906
1907async fn unshare_project_internal(
1908    project_id: ProjectId,
1909    connection_id: ConnectionId,
1910    session: &Session,
1911) -> Result<()> {
1912    let delete = {
1913        let room_guard = session
1914            .db()
1915            .await
1916            .unshare_project(project_id, connection_id)
1917            .await?;
1918
1919        let (delete, room, guest_connection_ids) = &*room_guard;
1920
1921        let message = proto::UnshareProject {
1922            project_id: project_id.to_proto(),
1923        };
1924
1925        broadcast(
1926            Some(connection_id),
1927            guest_connection_ids.iter().copied(),
1928            |conn_id| session.peer.send(conn_id, message.clone()),
1929        );
1930        if let Some(room) = room {
1931            room_updated(room, &session.peer);
1932        }
1933
1934        *delete
1935    };
1936
1937    if delete {
1938        let db = session.db().await;
1939        db.delete_project(project_id).await?;
1940    }
1941
1942    Ok(())
1943}
1944
1945/// Join someone elses shared project.
1946async fn join_project(
1947    request: proto::JoinProject,
1948    response: Response<proto::JoinProject>,
1949    session: MessageContext,
1950) -> Result<()> {
1951    let project_id = ProjectId::from_proto(request.project_id);
1952
1953    tracing::info!(%project_id, "join project");
1954
1955    let db = session.db().await;
1956    let (project, replica_id) = &mut *db
1957        .join_project(
1958            project_id,
1959            session.connection_id,
1960            session.user_id(),
1961            request.committer_name.clone(),
1962            request.committer_email.clone(),
1963        )
1964        .await?;
1965    drop(db);
1966    tracing::info!(%project_id, "join remote project");
1967    let collaborators = project
1968        .collaborators
1969        .iter()
1970        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1971        .map(|collaborator| collaborator.to_proto())
1972        .collect::<Vec<_>>();
1973    let project_id = project.id;
1974    let guest_user_id = session.user_id();
1975
1976    let worktrees = project
1977        .worktrees
1978        .iter()
1979        .map(|(id, worktree)| proto::WorktreeMetadata {
1980            id: *id,
1981            root_name: worktree.root_name.clone(),
1982            visible: worktree.visible,
1983            abs_path: worktree.abs_path.clone(),
1984        })
1985        .collect::<Vec<_>>();
1986
1987    let add_project_collaborator = proto::AddProjectCollaborator {
1988        project_id: project_id.to_proto(),
1989        collaborator: Some(proto::Collaborator {
1990            peer_id: Some(session.connection_id.into()),
1991            replica_id: replica_id.0 as u32,
1992            user_id: guest_user_id.to_proto(),
1993            is_host: false,
1994            committer_name: request.committer_name.clone(),
1995            committer_email: request.committer_email.clone(),
1996        }),
1997    };
1998
1999    for collaborator in &collaborators {
2000        session
2001            .peer
2002            .send(
2003                collaborator.peer_id.unwrap().into(),
2004                add_project_collaborator.clone(),
2005            )
2006            .trace_err();
2007    }
2008
2009    // First, we send the metadata associated with each worktree.
2010    let (language_servers, language_server_capabilities) = project
2011        .language_servers
2012        .clone()
2013        .into_iter()
2014        .map(|server| (server.server, server.capabilities))
2015        .unzip();
2016    response.send(proto::JoinProjectResponse {
2017        project_id: project.id.0 as u64,
2018        worktrees,
2019        replica_id: replica_id.0 as u32,
2020        collaborators,
2021        language_servers,
2022        language_server_capabilities,
2023        role: project.role.into(),
2024        windows_paths: project.path_style == PathStyle::Windows,
2025    })?;
2026
2027    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2028        // Stream this worktree's entries.
2029        let message = proto::UpdateWorktree {
2030            project_id: project_id.to_proto(),
2031            worktree_id,
2032            abs_path: worktree.abs_path.clone(),
2033            root_name: worktree.root_name,
2034            updated_entries: worktree.entries,
2035            removed_entries: Default::default(),
2036            scan_id: worktree.scan_id,
2037            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2038            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2039            removed_repositories: Default::default(),
2040        };
2041        for update in proto::split_worktree_update(message) {
2042            session.peer.send(session.connection_id, update.clone())?;
2043        }
2044
2045        // Stream this worktree's diagnostics.
2046        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
2047        if let Some(summary) = worktree_diagnostics.next() {
2048            let message = proto::UpdateDiagnosticSummary {
2049                project_id: project.id.to_proto(),
2050                worktree_id: worktree.id,
2051                summary: Some(summary),
2052                more_summaries: worktree_diagnostics.collect(),
2053            };
2054            session.peer.send(session.connection_id, message)?;
2055        }
2056
2057        for settings_file in worktree.settings_files {
2058            session.peer.send(
2059                session.connection_id,
2060                proto::UpdateWorktreeSettings {
2061                    project_id: project_id.to_proto(),
2062                    worktree_id: worktree.id,
2063                    path: settings_file.path,
2064                    content: Some(settings_file.content),
2065                    kind: Some(settings_file.kind.to_proto() as i32),
2066                },
2067            )?;
2068        }
2069    }
2070
2071    for repository in mem::take(&mut project.repositories) {
2072        for update in split_repository_update(repository) {
2073            session.peer.send(session.connection_id, update)?;
2074        }
2075    }
2076
2077    for language_server in &project.language_servers {
2078        session.peer.send(
2079            session.connection_id,
2080            proto::UpdateLanguageServer {
2081                project_id: project_id.to_proto(),
2082                server_name: Some(language_server.server.name.clone()),
2083                language_server_id: language_server.server.id,
2084                variant: Some(
2085                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2086                        proto::LspDiskBasedDiagnosticsUpdated {},
2087                    ),
2088                ),
2089            },
2090        )?;
2091    }
2092
2093    Ok(())
2094}
2095
2096/// Leave someone elses shared project.
2097async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2098    let sender_id = session.connection_id;
2099    let project_id = ProjectId::from_proto(request.project_id);
2100    let db = session.db().await;
2101
2102    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2103    tracing::info!(
2104        %project_id,
2105        "leave project"
2106    );
2107
2108    project_left(project, &session);
2109    if let Some(room) = room {
2110        room_updated(room, &session.peer);
2111    }
2112
2113    Ok(())
2114}
2115
2116/// Updates other participants with changes to the project
2117async fn update_project(
2118    request: proto::UpdateProject,
2119    response: Response<proto::UpdateProject>,
2120    session: MessageContext,
2121) -> Result<()> {
2122    let project_id = ProjectId::from_proto(request.project_id);
2123    let (room, guest_connection_ids) = &*session
2124        .db()
2125        .await
2126        .update_project(project_id, session.connection_id, &request.worktrees)
2127        .await?;
2128    broadcast(
2129        Some(session.connection_id),
2130        guest_connection_ids.iter().copied(),
2131        |connection_id| {
2132            session
2133                .peer
2134                .forward_send(session.connection_id, connection_id, request.clone())
2135        },
2136    );
2137    if let Some(room) = room {
2138        room_updated(room, &session.peer);
2139    }
2140    response.send(proto::Ack {})?;
2141
2142    Ok(())
2143}
2144
2145/// Updates other participants with changes to the worktree
2146async fn update_worktree(
2147    request: proto::UpdateWorktree,
2148    response: Response<proto::UpdateWorktree>,
2149    session: MessageContext,
2150) -> Result<()> {
2151    let guest_connection_ids = session
2152        .db()
2153        .await
2154        .update_worktree(&request, session.connection_id)
2155        .await?;
2156
2157    broadcast(
2158        Some(session.connection_id),
2159        guest_connection_ids.iter().copied(),
2160        |connection_id| {
2161            session
2162                .peer
2163                .forward_send(session.connection_id, connection_id, request.clone())
2164        },
2165    );
2166    response.send(proto::Ack {})?;
2167    Ok(())
2168}
2169
2170async fn update_repository(
2171    request: proto::UpdateRepository,
2172    response: Response<proto::UpdateRepository>,
2173    session: MessageContext,
2174) -> Result<()> {
2175    let guest_connection_ids = session
2176        .db()
2177        .await
2178        .update_repository(&request, session.connection_id)
2179        .await?;
2180
2181    broadcast(
2182        Some(session.connection_id),
2183        guest_connection_ids.iter().copied(),
2184        |connection_id| {
2185            session
2186                .peer
2187                .forward_send(session.connection_id, connection_id, request.clone())
2188        },
2189    );
2190    response.send(proto::Ack {})?;
2191    Ok(())
2192}
2193
2194async fn remove_repository(
2195    request: proto::RemoveRepository,
2196    response: Response<proto::RemoveRepository>,
2197    session: MessageContext,
2198) -> Result<()> {
2199    let guest_connection_ids = session
2200        .db()
2201        .await
2202        .remove_repository(&request, session.connection_id)
2203        .await?;
2204
2205    broadcast(
2206        Some(session.connection_id),
2207        guest_connection_ids.iter().copied(),
2208        |connection_id| {
2209            session
2210                .peer
2211                .forward_send(session.connection_id, connection_id, request.clone())
2212        },
2213    );
2214    response.send(proto::Ack {})?;
2215    Ok(())
2216}
2217
2218/// Updates other participants with changes to the diagnostics
2219async fn update_diagnostic_summary(
2220    message: proto::UpdateDiagnosticSummary,
2221    session: MessageContext,
2222) -> Result<()> {
2223    let guest_connection_ids = session
2224        .db()
2225        .await
2226        .update_diagnostic_summary(&message, session.connection_id)
2227        .await?;
2228
2229    broadcast(
2230        Some(session.connection_id),
2231        guest_connection_ids.iter().copied(),
2232        |connection_id| {
2233            session
2234                .peer
2235                .forward_send(session.connection_id, connection_id, message.clone())
2236        },
2237    );
2238
2239    Ok(())
2240}
2241
2242/// Updates other participants with changes to the worktree settings
2243async fn update_worktree_settings(
2244    message: proto::UpdateWorktreeSettings,
2245    session: MessageContext,
2246) -> Result<()> {
2247    let guest_connection_ids = session
2248        .db()
2249        .await
2250        .update_worktree_settings(&message, session.connection_id)
2251        .await?;
2252
2253    broadcast(
2254        Some(session.connection_id),
2255        guest_connection_ids.iter().copied(),
2256        |connection_id| {
2257            session
2258                .peer
2259                .forward_send(session.connection_id, connection_id, message.clone())
2260        },
2261    );
2262
2263    Ok(())
2264}
2265
2266/// Notify other participants that a language server has started.
2267async fn start_language_server(
2268    request: proto::StartLanguageServer,
2269    session: MessageContext,
2270) -> Result<()> {
2271    let guest_connection_ids = session
2272        .db()
2273        .await
2274        .start_language_server(&request, session.connection_id)
2275        .await?;
2276
2277    broadcast(
2278        Some(session.connection_id),
2279        guest_connection_ids.iter().copied(),
2280        |connection_id| {
2281            session
2282                .peer
2283                .forward_send(session.connection_id, connection_id, request.clone())
2284        },
2285    );
2286    Ok(())
2287}
2288
2289/// Notify other participants that a language server has changed.
2290async fn update_language_server(
2291    request: proto::UpdateLanguageServer,
2292    session: MessageContext,
2293) -> Result<()> {
2294    let project_id = ProjectId::from_proto(request.project_id);
2295    let db = session.db().await;
2296
2297    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2298        && let Some(capabilities) = update.capabilities.clone()
2299    {
2300        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2301            .await?;
2302    }
2303
2304    let project_connection_ids = db
2305        .project_connection_ids(project_id, session.connection_id, true)
2306        .await?;
2307    broadcast(
2308        Some(session.connection_id),
2309        project_connection_ids.iter().copied(),
2310        |connection_id| {
2311            session
2312                .peer
2313                .forward_send(session.connection_id, connection_id, request.clone())
2314        },
2315    );
2316    Ok(())
2317}
2318
2319/// forward a project request to the host. These requests should be read only
2320/// as guests are allowed to send them.
2321async fn forward_read_only_project_request<T>(
2322    request: T,
2323    response: Response<T>,
2324    session: MessageContext,
2325) -> Result<()>
2326where
2327    T: EntityMessage + RequestMessage,
2328{
2329    let project_id = ProjectId::from_proto(request.remote_entity_id());
2330    let host_connection_id = session
2331        .db()
2332        .await
2333        .host_for_read_only_project_request(project_id, session.connection_id)
2334        .await?;
2335    let payload = session.forward_request(host_connection_id, request).await?;
2336    response.send(payload)?;
2337    Ok(())
2338}
2339
2340/// forward a project request to the host. These requests are disallowed
2341/// for guests.
2342async fn forward_mutating_project_request<T>(
2343    request: T,
2344    response: Response<T>,
2345    session: MessageContext,
2346) -> Result<()>
2347where
2348    T: EntityMessage + RequestMessage,
2349{
2350    let project_id = ProjectId::from_proto(request.remote_entity_id());
2351
2352    let host_connection_id = session
2353        .db()
2354        .await
2355        .host_for_mutating_project_request(project_id, session.connection_id)
2356        .await?;
2357    let payload = session.forward_request(host_connection_id, request).await?;
2358    response.send(payload)?;
2359    Ok(())
2360}
2361
2362async fn lsp_query(
2363    request: proto::LspQuery,
2364    response: Response<proto::LspQuery>,
2365    session: MessageContext,
2366) -> Result<()> {
2367    let (name, should_write) = request.query_name_and_write_permissions();
2368    tracing::Span::current().record("lsp_query_request", name);
2369    tracing::info!("lsp_query message received");
2370    if should_write {
2371        forward_mutating_project_request(request, response, session).await
2372    } else {
2373        forward_read_only_project_request(request, response, session).await
2374    }
2375}
2376
2377/// Notify other participants that a new buffer has been created
2378async fn create_buffer_for_peer(
2379    request: proto::CreateBufferForPeer,
2380    session: MessageContext,
2381) -> Result<()> {
2382    session
2383        .db()
2384        .await
2385        .check_user_is_project_host(
2386            ProjectId::from_proto(request.project_id),
2387            session.connection_id,
2388        )
2389        .await?;
2390    let peer_id = request.peer_id.context("invalid peer id")?;
2391    session
2392        .peer
2393        .forward_send(session.connection_id, peer_id.into(), request)?;
2394    Ok(())
2395}
2396
2397/// Notify other participants that a new image has been created
2398async fn create_image_for_peer(
2399    request: proto::CreateImageForPeer,
2400    session: MessageContext,
2401) -> Result<()> {
2402    session
2403        .db()
2404        .await
2405        .check_user_is_project_host(
2406            ProjectId::from_proto(request.project_id),
2407            session.connection_id,
2408        )
2409        .await?;
2410    let peer_id = request.peer_id.context("invalid peer id")?;
2411    session
2412        .peer
2413        .forward_send(session.connection_id, peer_id.into(), request)?;
2414    Ok(())
2415}
2416
2417/// Notify other participants that a buffer has been updated. This is
2418/// allowed for guests as long as the update is limited to selections.
2419async fn update_buffer(
2420    request: proto::UpdateBuffer,
2421    response: Response<proto::UpdateBuffer>,
2422    session: MessageContext,
2423) -> Result<()> {
2424    let project_id = ProjectId::from_proto(request.project_id);
2425    let mut capability = Capability::ReadOnly;
2426
2427    for op in request.operations.iter() {
2428        match op.variant {
2429            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2430            Some(_) => capability = Capability::ReadWrite,
2431        }
2432    }
2433
2434    let host = {
2435        let guard = session
2436            .db()
2437            .await
2438            .connections_for_buffer_update(project_id, session.connection_id, capability)
2439            .await?;
2440
2441        let (host, guests) = &*guard;
2442
2443        broadcast(
2444            Some(session.connection_id),
2445            guests.clone(),
2446            |connection_id| {
2447                session
2448                    .peer
2449                    .forward_send(session.connection_id, connection_id, request.clone())
2450            },
2451        );
2452
2453        *host
2454    };
2455
2456    if host != session.connection_id {
2457        session.forward_request(host, request.clone()).await?;
2458    }
2459
2460    response.send(proto::Ack {})?;
2461    Ok(())
2462}
2463
2464async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2465    let project_id = ProjectId::from_proto(message.project_id);
2466
2467    let operation = message.operation.as_ref().context("invalid operation")?;
2468    let capability = match operation.variant.as_ref() {
2469        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2470            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2471                match buffer_op.variant {
2472                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2473                        Capability::ReadOnly
2474                    }
2475                    _ => Capability::ReadWrite,
2476                }
2477            } else {
2478                Capability::ReadWrite
2479            }
2480        }
2481        Some(_) => Capability::ReadWrite,
2482        None => Capability::ReadOnly,
2483    };
2484
2485    let guard = session
2486        .db()
2487        .await
2488        .connections_for_buffer_update(project_id, session.connection_id, capability)
2489        .await?;
2490
2491    let (host, guests) = &*guard;
2492
2493    broadcast(
2494        Some(session.connection_id),
2495        guests.iter().chain([host]).copied(),
2496        |connection_id| {
2497            session
2498                .peer
2499                .forward_send(session.connection_id, connection_id, message.clone())
2500        },
2501    );
2502
2503    Ok(())
2504}
2505
2506/// Notify other participants that a project has been updated.
2507async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2508    request: T,
2509    session: MessageContext,
2510) -> Result<()> {
2511    let project_id = ProjectId::from_proto(request.remote_entity_id());
2512    let project_connection_ids = session
2513        .db()
2514        .await
2515        .project_connection_ids(project_id, session.connection_id, false)
2516        .await?;
2517
2518    broadcast(
2519        Some(session.connection_id),
2520        project_connection_ids.iter().copied(),
2521        |connection_id| {
2522            session
2523                .peer
2524                .forward_send(session.connection_id, connection_id, request.clone())
2525        },
2526    );
2527    Ok(())
2528}
2529
2530/// Start following another user in a call.
2531async fn follow(
2532    request: proto::Follow,
2533    response: Response<proto::Follow>,
2534    session: MessageContext,
2535) -> Result<()> {
2536    let room_id = RoomId::from_proto(request.room_id);
2537    let project_id = request.project_id.map(ProjectId::from_proto);
2538    let leader_id = request.leader_id.context("invalid leader id")?.into();
2539    let follower_id = session.connection_id;
2540
2541    session
2542        .db()
2543        .await
2544        .check_room_participants(room_id, leader_id, session.connection_id)
2545        .await?;
2546
2547    let response_payload = session.forward_request(leader_id, request).await?;
2548    response.send(response_payload)?;
2549
2550    if let Some(project_id) = project_id {
2551        let room = session
2552            .db()
2553            .await
2554            .follow(room_id, project_id, leader_id, follower_id)
2555            .await?;
2556        room_updated(&room, &session.peer);
2557    }
2558
2559    Ok(())
2560}
2561
2562/// Stop following another user in a call.
2563async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2564    let room_id = RoomId::from_proto(request.room_id);
2565    let project_id = request.project_id.map(ProjectId::from_proto);
2566    let leader_id = request.leader_id.context("invalid leader id")?.into();
2567    let follower_id = session.connection_id;
2568
2569    session
2570        .db()
2571        .await
2572        .check_room_participants(room_id, leader_id, session.connection_id)
2573        .await?;
2574
2575    session
2576        .peer
2577        .forward_send(session.connection_id, leader_id, request)?;
2578
2579    if let Some(project_id) = project_id {
2580        let room = session
2581            .db()
2582            .await
2583            .unfollow(room_id, project_id, leader_id, follower_id)
2584            .await?;
2585        room_updated(&room, &session.peer);
2586    }
2587
2588    Ok(())
2589}
2590
2591/// Notify everyone following you of your current location.
2592async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2593    let room_id = RoomId::from_proto(request.room_id);
2594    let database = session.db.lock().await;
2595
2596    let connection_ids = if let Some(project_id) = request.project_id {
2597        let project_id = ProjectId::from_proto(project_id);
2598        database
2599            .project_connection_ids(project_id, session.connection_id, true)
2600            .await?
2601    } else {
2602        database
2603            .room_connection_ids(room_id, session.connection_id)
2604            .await?
2605    };
2606
2607    // For now, don't send view update messages back to that view's current leader.
2608    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2609        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2610        _ => None,
2611    });
2612
2613    for connection_id in connection_ids.iter().cloned() {
2614        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2615            session
2616                .peer
2617                .forward_send(session.connection_id, connection_id, request.clone())?;
2618        }
2619    }
2620    Ok(())
2621}
2622
2623/// Get public data about users.
2624async fn get_users(
2625    request: proto::GetUsers,
2626    response: Response<proto::GetUsers>,
2627    session: MessageContext,
2628) -> Result<()> {
2629    let user_ids = request
2630        .user_ids
2631        .into_iter()
2632        .map(UserId::from_proto)
2633        .collect();
2634    let users = session
2635        .db()
2636        .await
2637        .get_users_by_ids(user_ids)
2638        .await?
2639        .into_iter()
2640        .map(|user| proto::User {
2641            id: user.id.to_proto(),
2642            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2643            github_login: user.github_login,
2644            name: user.name,
2645        })
2646        .collect();
2647    response.send(proto::UsersResponse { users })?;
2648    Ok(())
2649}
2650
2651/// Search for users (to invite) buy Github login
2652async fn fuzzy_search_users(
2653    request: proto::FuzzySearchUsers,
2654    response: Response<proto::FuzzySearchUsers>,
2655    session: MessageContext,
2656) -> Result<()> {
2657    let query = request.query;
2658    let users = match query.len() {
2659        0 => vec![],
2660        1 | 2 => session
2661            .db()
2662            .await
2663            .get_user_by_github_login(&query)
2664            .await?
2665            .into_iter()
2666            .collect(),
2667        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2668    };
2669    let users = users
2670        .into_iter()
2671        .filter(|user| user.id != session.user_id())
2672        .map(|user| proto::User {
2673            id: user.id.to_proto(),
2674            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2675            github_login: user.github_login,
2676            name: user.name,
2677        })
2678        .collect();
2679    response.send(proto::UsersResponse { users })?;
2680    Ok(())
2681}
2682
2683/// Send a contact request to another user.
2684async fn request_contact(
2685    request: proto::RequestContact,
2686    response: Response<proto::RequestContact>,
2687    session: MessageContext,
2688) -> Result<()> {
2689    let requester_id = session.user_id();
2690    let responder_id = UserId::from_proto(request.responder_id);
2691    if requester_id == responder_id {
2692        return Err(anyhow!("cannot add yourself as a contact"))?;
2693    }
2694
2695    let notifications = session
2696        .db()
2697        .await
2698        .send_contact_request(requester_id, responder_id)
2699        .await?;
2700
2701    // Update outgoing contact requests of requester
2702    let mut update = proto::UpdateContacts::default();
2703    update.outgoing_requests.push(responder_id.to_proto());
2704    for connection_id in session
2705        .connection_pool()
2706        .await
2707        .user_connection_ids(requester_id)
2708    {
2709        session.peer.send(connection_id, update.clone())?;
2710    }
2711
2712    // Update incoming contact requests of responder
2713    let mut update = proto::UpdateContacts::default();
2714    update
2715        .incoming_requests
2716        .push(proto::IncomingContactRequest {
2717            requester_id: requester_id.to_proto(),
2718        });
2719    let connection_pool = session.connection_pool().await;
2720    for connection_id in connection_pool.user_connection_ids(responder_id) {
2721        session.peer.send(connection_id, update.clone())?;
2722    }
2723
2724    send_notifications(&connection_pool, &session.peer, notifications);
2725
2726    response.send(proto::Ack {})?;
2727    Ok(())
2728}
2729
2730/// Accept or decline a contact request
2731async fn respond_to_contact_request(
2732    request: proto::RespondToContactRequest,
2733    response: Response<proto::RespondToContactRequest>,
2734    session: MessageContext,
2735) -> Result<()> {
2736    let responder_id = session.user_id();
2737    let requester_id = UserId::from_proto(request.requester_id);
2738    let db = session.db().await;
2739    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2740        db.dismiss_contact_notification(responder_id, requester_id)
2741            .await?;
2742    } else {
2743        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2744
2745        let notifications = db
2746            .respond_to_contact_request(responder_id, requester_id, accept)
2747            .await?;
2748        let requester_busy = db.is_user_busy(requester_id).await?;
2749        let responder_busy = db.is_user_busy(responder_id).await?;
2750
2751        let pool = session.connection_pool().await;
2752        // Update responder with new contact
2753        let mut update = proto::UpdateContacts::default();
2754        if accept {
2755            update
2756                .contacts
2757                .push(contact_for_user(requester_id, requester_busy, &pool));
2758        }
2759        update
2760            .remove_incoming_requests
2761            .push(requester_id.to_proto());
2762        for connection_id in pool.user_connection_ids(responder_id) {
2763            session.peer.send(connection_id, update.clone())?;
2764        }
2765
2766        // Update requester with new contact
2767        let mut update = proto::UpdateContacts::default();
2768        if accept {
2769            update
2770                .contacts
2771                .push(contact_for_user(responder_id, responder_busy, &pool));
2772        }
2773        update
2774            .remove_outgoing_requests
2775            .push(responder_id.to_proto());
2776
2777        for connection_id in pool.user_connection_ids(requester_id) {
2778            session.peer.send(connection_id, update.clone())?;
2779        }
2780
2781        send_notifications(&pool, &session.peer, notifications);
2782    }
2783
2784    response.send(proto::Ack {})?;
2785    Ok(())
2786}
2787
2788/// Remove a contact.
2789async fn remove_contact(
2790    request: proto::RemoveContact,
2791    response: Response<proto::RemoveContact>,
2792    session: MessageContext,
2793) -> Result<()> {
2794    let requester_id = session.user_id();
2795    let responder_id = UserId::from_proto(request.user_id);
2796    let db = session.db().await;
2797    let (contact_accepted, deleted_notification_id) =
2798        db.remove_contact(requester_id, responder_id).await?;
2799
2800    let pool = session.connection_pool().await;
2801    // Update outgoing contact requests of requester
2802    let mut update = proto::UpdateContacts::default();
2803    if contact_accepted {
2804        update.remove_contacts.push(responder_id.to_proto());
2805    } else {
2806        update
2807            .remove_outgoing_requests
2808            .push(responder_id.to_proto());
2809    }
2810    for connection_id in pool.user_connection_ids(requester_id) {
2811        session.peer.send(connection_id, update.clone())?;
2812    }
2813
2814    // Update incoming contact requests of responder
2815    let mut update = proto::UpdateContacts::default();
2816    if contact_accepted {
2817        update.remove_contacts.push(requester_id.to_proto());
2818    } else {
2819        update
2820            .remove_incoming_requests
2821            .push(requester_id.to_proto());
2822    }
2823    for connection_id in pool.user_connection_ids(responder_id) {
2824        session.peer.send(connection_id, update.clone())?;
2825        if let Some(notification_id) = deleted_notification_id {
2826            session.peer.send(
2827                connection_id,
2828                proto::DeleteNotification {
2829                    notification_id: notification_id.to_proto(),
2830                },
2831            )?;
2832        }
2833    }
2834
2835    response.send(proto::Ack {})?;
2836    Ok(())
2837}
2838
2839fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
2840    version.0.minor < 139
2841}
2842
2843async fn subscribe_to_channels(
2844    _: proto::SubscribeToChannels,
2845    session: MessageContext,
2846) -> Result<()> {
2847    subscribe_user_to_channels(session.user_id(), &session).await?;
2848    Ok(())
2849}
2850
2851async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2852    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2853    let mut pool = session.connection_pool().await;
2854    for membership in &channels_for_user.channel_memberships {
2855        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2856    }
2857    session.peer.send(
2858        session.connection_id,
2859        build_update_user_channels(&channels_for_user),
2860    )?;
2861    session.peer.send(
2862        session.connection_id,
2863        build_channels_update(channels_for_user),
2864    )?;
2865    Ok(())
2866}
2867
2868/// Creates a new channel.
2869async fn create_channel(
2870    request: proto::CreateChannel,
2871    response: Response<proto::CreateChannel>,
2872    session: MessageContext,
2873) -> Result<()> {
2874    let db = session.db().await;
2875
2876    let parent_id = request.parent_id.map(ChannelId::from_proto);
2877    let (channel, membership) = db
2878        .create_channel(&request.name, parent_id, session.user_id())
2879        .await?;
2880
2881    let root_id = channel.root_id();
2882    let channel = Channel::from_model(channel);
2883
2884    response.send(proto::CreateChannelResponse {
2885        channel: Some(channel.to_proto()),
2886        parent_id: request.parent_id,
2887    })?;
2888
2889    let mut connection_pool = session.connection_pool().await;
2890    if let Some(membership) = membership {
2891        connection_pool.subscribe_to_channel(
2892            membership.user_id,
2893            membership.channel_id,
2894            membership.role,
2895        );
2896        let update = proto::UpdateUserChannels {
2897            channel_memberships: vec![proto::ChannelMembership {
2898                channel_id: membership.channel_id.to_proto(),
2899                role: membership.role.into(),
2900            }],
2901            ..Default::default()
2902        };
2903        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2904            session.peer.send(connection_id, update.clone())?;
2905        }
2906    }
2907
2908    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2909        if !role.can_see_channel(channel.visibility) {
2910            continue;
2911        }
2912
2913        let update = proto::UpdateChannels {
2914            channels: vec![channel.to_proto()],
2915            ..Default::default()
2916        };
2917        session.peer.send(connection_id, update.clone())?;
2918    }
2919
2920    Ok(())
2921}
2922
2923/// Delete a channel
2924async fn delete_channel(
2925    request: proto::DeleteChannel,
2926    response: Response<proto::DeleteChannel>,
2927    session: MessageContext,
2928) -> Result<()> {
2929    let db = session.db().await;
2930
2931    let channel_id = request.channel_id;
2932    let (root_channel, removed_channels) = db
2933        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2934        .await?;
2935    response.send(proto::Ack {})?;
2936
2937    // Notify members of removed channels
2938    let mut update = proto::UpdateChannels::default();
2939    update
2940        .delete_channels
2941        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2942
2943    let connection_pool = session.connection_pool().await;
2944    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2945        session.peer.send(connection_id, update.clone())?;
2946    }
2947
2948    Ok(())
2949}
2950
2951/// Invite someone to join a channel.
2952async fn invite_channel_member(
2953    request: proto::InviteChannelMember,
2954    response: Response<proto::InviteChannelMember>,
2955    session: MessageContext,
2956) -> Result<()> {
2957    let db = session.db().await;
2958    let channel_id = ChannelId::from_proto(request.channel_id);
2959    let invitee_id = UserId::from_proto(request.user_id);
2960    let InviteMemberResult {
2961        channel,
2962        notifications,
2963    } = db
2964        .invite_channel_member(
2965            channel_id,
2966            invitee_id,
2967            session.user_id(),
2968            request.role().into(),
2969        )
2970        .await?;
2971
2972    let update = proto::UpdateChannels {
2973        channel_invitations: vec![channel.to_proto()],
2974        ..Default::default()
2975    };
2976
2977    let connection_pool = session.connection_pool().await;
2978    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2979        session.peer.send(connection_id, update.clone())?;
2980    }
2981
2982    send_notifications(&connection_pool, &session.peer, notifications);
2983
2984    response.send(proto::Ack {})?;
2985    Ok(())
2986}
2987
2988/// remove someone from a channel
2989async fn remove_channel_member(
2990    request: proto::RemoveChannelMember,
2991    response: Response<proto::RemoveChannelMember>,
2992    session: MessageContext,
2993) -> Result<()> {
2994    let db = session.db().await;
2995    let channel_id = ChannelId::from_proto(request.channel_id);
2996    let member_id = UserId::from_proto(request.user_id);
2997
2998    let RemoveChannelMemberResult {
2999        membership_update,
3000        notification_id,
3001    } = db
3002        .remove_channel_member(channel_id, member_id, session.user_id())
3003        .await?;
3004
3005    let mut connection_pool = session.connection_pool().await;
3006    notify_membership_updated(
3007        &mut connection_pool,
3008        membership_update,
3009        member_id,
3010        &session.peer,
3011    );
3012    for connection_id in connection_pool.user_connection_ids(member_id) {
3013        if let Some(notification_id) = notification_id {
3014            session
3015                .peer
3016                .send(
3017                    connection_id,
3018                    proto::DeleteNotification {
3019                        notification_id: notification_id.to_proto(),
3020                    },
3021                )
3022                .trace_err();
3023        }
3024    }
3025
3026    response.send(proto::Ack {})?;
3027    Ok(())
3028}
3029
3030/// Toggle the channel between public and private.
3031/// Care is taken to maintain the invariant that public channels only descend from public channels,
3032/// (though members-only channels can appear at any point in the hierarchy).
3033async fn set_channel_visibility(
3034    request: proto::SetChannelVisibility,
3035    response: Response<proto::SetChannelVisibility>,
3036    session: MessageContext,
3037) -> Result<()> {
3038    let db = session.db().await;
3039    let channel_id = ChannelId::from_proto(request.channel_id);
3040    let visibility = request.visibility().into();
3041
3042    let channel_model = db
3043        .set_channel_visibility(channel_id, visibility, session.user_id())
3044        .await?;
3045    let root_id = channel_model.root_id();
3046    let channel = Channel::from_model(channel_model);
3047
3048    let mut connection_pool = session.connection_pool().await;
3049    for (user_id, role) in connection_pool
3050        .channel_user_ids(root_id)
3051        .collect::<Vec<_>>()
3052        .into_iter()
3053    {
3054        let update = if role.can_see_channel(channel.visibility) {
3055            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3056            proto::UpdateChannels {
3057                channels: vec![channel.to_proto()],
3058                ..Default::default()
3059            }
3060        } else {
3061            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3062            proto::UpdateChannels {
3063                delete_channels: vec![channel.id.to_proto()],
3064                ..Default::default()
3065            }
3066        };
3067
3068        for connection_id in connection_pool.user_connection_ids(user_id) {
3069            session.peer.send(connection_id, update.clone())?;
3070        }
3071    }
3072
3073    response.send(proto::Ack {})?;
3074    Ok(())
3075}
3076
3077/// Alter the role for a user in the channel.
3078async fn set_channel_member_role(
3079    request: proto::SetChannelMemberRole,
3080    response: Response<proto::SetChannelMemberRole>,
3081    session: MessageContext,
3082) -> Result<()> {
3083    let db = session.db().await;
3084    let channel_id = ChannelId::from_proto(request.channel_id);
3085    let member_id = UserId::from_proto(request.user_id);
3086    let result = db
3087        .set_channel_member_role(
3088            channel_id,
3089            session.user_id(),
3090            member_id,
3091            request.role().into(),
3092        )
3093        .await?;
3094
3095    match result {
3096        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3097            let mut connection_pool = session.connection_pool().await;
3098            notify_membership_updated(
3099                &mut connection_pool,
3100                membership_update,
3101                member_id,
3102                &session.peer,
3103            )
3104        }
3105        db::SetMemberRoleResult::InviteUpdated(channel) => {
3106            let update = proto::UpdateChannels {
3107                channel_invitations: vec![channel.to_proto()],
3108                ..Default::default()
3109            };
3110
3111            for connection_id in session
3112                .connection_pool()
3113                .await
3114                .user_connection_ids(member_id)
3115            {
3116                session.peer.send(connection_id, update.clone())?;
3117            }
3118        }
3119    }
3120
3121    response.send(proto::Ack {})?;
3122    Ok(())
3123}
3124
3125/// Change the name of a channel
3126async fn rename_channel(
3127    request: proto::RenameChannel,
3128    response: Response<proto::RenameChannel>,
3129    session: MessageContext,
3130) -> Result<()> {
3131    let db = session.db().await;
3132    let channel_id = ChannelId::from_proto(request.channel_id);
3133    let channel_model = db
3134        .rename_channel(channel_id, session.user_id(), &request.name)
3135        .await?;
3136    let root_id = channel_model.root_id();
3137    let channel = Channel::from_model(channel_model);
3138
3139    response.send(proto::RenameChannelResponse {
3140        channel: Some(channel.to_proto()),
3141    })?;
3142
3143    let connection_pool = session.connection_pool().await;
3144    let update = proto::UpdateChannels {
3145        channels: vec![channel.to_proto()],
3146        ..Default::default()
3147    };
3148    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3149        if role.can_see_channel(channel.visibility) {
3150            session.peer.send(connection_id, update.clone())?;
3151        }
3152    }
3153
3154    Ok(())
3155}
3156
3157/// Move a channel to a new parent.
3158async fn move_channel(
3159    request: proto::MoveChannel,
3160    response: Response<proto::MoveChannel>,
3161    session: MessageContext,
3162) -> Result<()> {
3163    let channel_id = ChannelId::from_proto(request.channel_id);
3164    let to = ChannelId::from_proto(request.to);
3165
3166    let (root_id, channels) = session
3167        .db()
3168        .await
3169        .move_channel(channel_id, to, session.user_id())
3170        .await?;
3171
3172    let connection_pool = session.connection_pool().await;
3173    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3174        let channels = channels
3175            .iter()
3176            .filter_map(|channel| {
3177                if role.can_see_channel(channel.visibility) {
3178                    Some(channel.to_proto())
3179                } else {
3180                    None
3181                }
3182            })
3183            .collect::<Vec<_>>();
3184        if channels.is_empty() {
3185            continue;
3186        }
3187
3188        let update = proto::UpdateChannels {
3189            channels,
3190            ..Default::default()
3191        };
3192
3193        session.peer.send(connection_id, update.clone())?;
3194    }
3195
3196    response.send(Ack {})?;
3197    Ok(())
3198}
3199
3200async fn reorder_channel(
3201    request: proto::ReorderChannel,
3202    response: Response<proto::ReorderChannel>,
3203    session: MessageContext,
3204) -> Result<()> {
3205    let channel_id = ChannelId::from_proto(request.channel_id);
3206    let direction = request.direction();
3207
3208    let updated_channels = session
3209        .db()
3210        .await
3211        .reorder_channel(channel_id, direction, session.user_id())
3212        .await?;
3213
3214    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3215        let connection_pool = session.connection_pool().await;
3216        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3217            let channels = updated_channels
3218                .iter()
3219                .filter_map(|channel| {
3220                    if role.can_see_channel(channel.visibility) {
3221                        Some(channel.to_proto())
3222                    } else {
3223                        None
3224                    }
3225                })
3226                .collect::<Vec<_>>();
3227
3228            if channels.is_empty() {
3229                continue;
3230            }
3231
3232            let update = proto::UpdateChannels {
3233                channels,
3234                ..Default::default()
3235            };
3236
3237            session.peer.send(connection_id, update.clone())?;
3238        }
3239    }
3240
3241    response.send(Ack {})?;
3242    Ok(())
3243}
3244
3245/// Get the list of channel members
3246async fn get_channel_members(
3247    request: proto::GetChannelMembers,
3248    response: Response<proto::GetChannelMembers>,
3249    session: MessageContext,
3250) -> Result<()> {
3251    let db = session.db().await;
3252    let channel_id = ChannelId::from_proto(request.channel_id);
3253    let limit = if request.limit == 0 {
3254        u16::MAX as u64
3255    } else {
3256        request.limit
3257    };
3258    let (members, users) = db
3259        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3260        .await?;
3261    response.send(proto::GetChannelMembersResponse { members, users })?;
3262    Ok(())
3263}
3264
3265/// Accept or decline a channel invitation.
3266async fn respond_to_channel_invite(
3267    request: proto::RespondToChannelInvite,
3268    response: Response<proto::RespondToChannelInvite>,
3269    session: MessageContext,
3270) -> Result<()> {
3271    let db = session.db().await;
3272    let channel_id = ChannelId::from_proto(request.channel_id);
3273    let RespondToChannelInvite {
3274        membership_update,
3275        notifications,
3276    } = db
3277        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3278        .await?;
3279
3280    let mut connection_pool = session.connection_pool().await;
3281    if let Some(membership_update) = membership_update {
3282        notify_membership_updated(
3283            &mut connection_pool,
3284            membership_update,
3285            session.user_id(),
3286            &session.peer,
3287        );
3288    } else {
3289        let update = proto::UpdateChannels {
3290            remove_channel_invitations: vec![channel_id.to_proto()],
3291            ..Default::default()
3292        };
3293
3294        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3295            session.peer.send(connection_id, update.clone())?;
3296        }
3297    };
3298
3299    send_notifications(&connection_pool, &session.peer, notifications);
3300
3301    response.send(proto::Ack {})?;
3302
3303    Ok(())
3304}
3305
3306/// Join the channels' room
3307async fn join_channel(
3308    request: proto::JoinChannel,
3309    response: Response<proto::JoinChannel>,
3310    session: MessageContext,
3311) -> Result<()> {
3312    let channel_id = ChannelId::from_proto(request.channel_id);
3313    join_channel_internal(channel_id, Box::new(response), session).await
3314}
3315
3316trait JoinChannelInternalResponse {
3317    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3318}
3319impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3320    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3321        Response::<proto::JoinChannel>::send(self, result)
3322    }
3323}
3324impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3325    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3326        Response::<proto::JoinRoom>::send(self, result)
3327    }
3328}
3329
/// Join the room attached to `channel_id` for the session's user and reply
/// through `response`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, that
///    connection is first removed from its room via `leave_room_for_session`.
/// 2. The user joins the channel's room in the database.
/// 3. When a LiveKit client is configured, credentials are minted: guests get
///    a guest token with `can_publish: false`, every other role gets a room
///    token with `can_publish: true`. A token error is logged and results in
///    no LiveKit info rather than a failed join.
/// 4. The join response is sent, then any membership change and the new room
///    state are broadcast.
/// 5. Channel participants and the user's contacts are notified.
///
/// # Errors
/// Returns database errors, a failure to send the response, and an error if
/// the joined room has no channel attached. That last check happens after the
/// response has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard first: `leave_room_for_session`
            // acquires `session.db()` itself, then re-acquire it afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured. Inside the closure,
        // `trace_err()?` logs a token error and makes the whole value `None`.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // The reply goes out before any of the broadcasts below.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        // NOTE: the database guard `db` is still held while the connection
        // pool is locked here; both are released at the end of this block.
        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // A missing channel errors here, after the client was already answered.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3423
3424/// Start editing the channel notes
3425async fn join_channel_buffer(
3426    request: proto::JoinChannelBuffer,
3427    response: Response<proto::JoinChannelBuffer>,
3428    session: MessageContext,
3429) -> Result<()> {
3430    let db = session.db().await;
3431    let channel_id = ChannelId::from_proto(request.channel_id);
3432
3433    let open_response = db
3434        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3435        .await?;
3436
3437    let collaborators = open_response.collaborators.clone();
3438    response.send(open_response)?;
3439
3440    let update = UpdateChannelBufferCollaborators {
3441        channel_id: channel_id.to_proto(),
3442        collaborators: collaborators.clone(),
3443    };
3444    channel_buffer_updated(
3445        session.connection_id,
3446        collaborators
3447            .iter()
3448            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3449        &update,
3450        &session.peer,
3451    );
3452
3453    Ok(())
3454}
3455
3456/// Edit the channel notes
3457async fn update_channel_buffer(
3458    request: proto::UpdateChannelBuffer,
3459    session: MessageContext,
3460) -> Result<()> {
3461    let db = session.db().await;
3462    let channel_id = ChannelId::from_proto(request.channel_id);
3463
3464    let (collaborators, epoch, version) = db
3465        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3466        .await?;
3467
3468    channel_buffer_updated(
3469        session.connection_id,
3470        collaborators.clone(),
3471        &proto::UpdateChannelBuffer {
3472            channel_id: channel_id.to_proto(),
3473            operations: request.operations,
3474        },
3475        &session.peer,
3476    );
3477
3478    let pool = &*session.connection_pool().await;
3479
3480    let non_collaborators =
3481        pool.channel_connection_ids(channel_id)
3482            .filter_map(|(connection_id, _)| {
3483                if collaborators.contains(&connection_id) {
3484                    None
3485                } else {
3486                    Some(connection_id)
3487                }
3488            });
3489
3490    broadcast(None, non_collaborators, |peer_id| {
3491        session.peer.send(
3492            peer_id,
3493            proto::UpdateChannels {
3494                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3495                    channel_id: channel_id.to_proto(),
3496                    epoch: epoch as u64,
3497                    version: version.clone(),
3498                }],
3499                ..Default::default()
3500            },
3501        )
3502    });
3503
3504    Ok(())
3505}
3506
3507/// Rejoin the channel notes after a connection blip
3508async fn rejoin_channel_buffers(
3509    request: proto::RejoinChannelBuffers,
3510    response: Response<proto::RejoinChannelBuffers>,
3511    session: MessageContext,
3512) -> Result<()> {
3513    let db = session.db().await;
3514    let buffers = db
3515        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3516        .await?;
3517
3518    for rejoined_buffer in &buffers {
3519        let collaborators_to_notify = rejoined_buffer
3520            .buffer
3521            .collaborators
3522            .iter()
3523            .filter_map(|c| Some(c.peer_id?.into()));
3524        channel_buffer_updated(
3525            session.connection_id,
3526            collaborators_to_notify,
3527            &proto::UpdateChannelBufferCollaborators {
3528                channel_id: rejoined_buffer.buffer.channel_id,
3529                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3530            },
3531            &session.peer,
3532        );
3533    }
3534
3535    response.send(proto::RejoinChannelBuffersResponse {
3536        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3537    })?;
3538
3539    Ok(())
3540}
3541
3542/// Stop editing the channel notes
3543async fn leave_channel_buffer(
3544    request: proto::LeaveChannelBuffer,
3545    response: Response<proto::LeaveChannelBuffer>,
3546    session: MessageContext,
3547) -> Result<()> {
3548    let db = session.db().await;
3549    let channel_id = ChannelId::from_proto(request.channel_id);
3550
3551    let left_buffer = db
3552        .leave_channel_buffer(channel_id, session.connection_id)
3553        .await?;
3554
3555    response.send(Ack {})?;
3556
3557    channel_buffer_updated(
3558        session.connection_id,
3559        left_buffer.connections,
3560        &proto::UpdateChannelBufferCollaborators {
3561            channel_id: channel_id.to_proto(),
3562            collaborators: left_buffer.collaborators,
3563        },
3564        &session.peer,
3565    );
3566
3567    Ok(())
3568}
3569
3570fn channel_buffer_updated<T: EnvelopedMessage>(
3571    sender_id: ConnectionId,
3572    collaborators: impl IntoIterator<Item = ConnectionId>,
3573    message: &T,
3574    peer: &Peer,
3575) {
3576    broadcast(Some(sender_id), collaborators, |peer_id| {
3577        peer.send(peer_id, message.clone())
3578    });
3579}
3580
3581fn send_notifications(
3582    connection_pool: &ConnectionPool,
3583    peer: &Peer,
3584    notifications: db::NotificationBatch,
3585) {
3586    for (user_id, notification) in notifications {
3587        for connection_id in connection_pool.user_connection_ids(user_id) {
3588            if let Err(error) = peer.send(
3589                connection_id,
3590                proto::AddNotification {
3591                    notification: Some(notification.clone()),
3592                },
3593            ) {
3594                tracing::error!(
3595                    "failed to send notification to {:?} {}",
3596                    connection_id,
3597                    error
3598                );
3599            }
3600        }
3601    }
3602}
3603
3604/// Send a message to the channel
3605async fn send_channel_message(
3606    _request: proto::SendChannelMessage,
3607    _response: Response<proto::SendChannelMessage>,
3608    _session: MessageContext,
3609) -> Result<()> {
3610    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3611}
3612
3613/// Delete a channel message
3614async fn remove_channel_message(
3615    _request: proto::RemoveChannelMessage,
3616    _response: Response<proto::RemoveChannelMessage>,
3617    _session: MessageContext,
3618) -> Result<()> {
3619    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3620}
3621
3622async fn update_channel_message(
3623    _request: proto::UpdateChannelMessage,
3624    _response: Response<proto::UpdateChannelMessage>,
3625    _session: MessageContext,
3626) -> Result<()> {
3627    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3628}
3629
3630/// Mark a channel message as read
3631async fn acknowledge_channel_message(
3632    _request: proto::AckChannelMessage,
3633    _session: MessageContext,
3634) -> Result<()> {
3635    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3636}
3637
3638/// Mark a buffer version as synced
3639async fn acknowledge_buffer_version(
3640    request: proto::AckBufferOperation,
3641    session: MessageContext,
3642) -> Result<()> {
3643    let buffer_id = BufferId::from_proto(request.buffer_id);
3644    session
3645        .db()
3646        .await
3647        .observe_buffer_version(
3648            buffer_id,
3649            session.user_id(),
3650            request.epoch as i32,
3651            &request.version,
3652        )
3653        .await?;
3654    Ok(())
3655}
3656
3657/// Get a Supermaven API key for the user
3658async fn get_supermaven_api_key(
3659    _request: proto::GetSupermavenApiKey,
3660    response: Response<proto::GetSupermavenApiKey>,
3661    session: MessageContext,
3662) -> Result<()> {
3663    let user_id: String = session.user_id().to_string();
3664    if !session.is_staff() {
3665        return Err(anyhow!("supermaven not enabled for this account"))?;
3666    }
3667
3668    let email = session.email().context("user must have an email")?;
3669
3670    let supermaven_admin_api = session
3671        .supermaven_client
3672        .as_ref()
3673        .context("supermaven not configured")?;
3674
3675    let result = supermaven_admin_api
3676        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3677        .await?;
3678
3679    response.send(proto::GetSupermavenApiKeyResponse {
3680        api_key: result.api_key,
3681    })?;
3682
3683    Ok(())
3684}
3685
3686/// Start receiving chat updates for a channel
3687async fn join_channel_chat(
3688    _request: proto::JoinChannelChat,
3689    _response: Response<proto::JoinChannelChat>,
3690    _session: MessageContext,
3691) -> Result<()> {
3692    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3693}
3694
3695/// Stop receiving chat updates for a channel
3696async fn leave_channel_chat(
3697    _request: proto::LeaveChannelChat,
3698    _session: MessageContext,
3699) -> Result<()> {
3700    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3701}
3702
3703/// Retrieve the chat history for a channel
3704async fn get_channel_messages(
3705    _request: proto::GetChannelMessages,
3706    _response: Response<proto::GetChannelMessages>,
3707    _session: MessageContext,
3708) -> Result<()> {
3709    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3710}
3711
3712/// Retrieve specific chat messages
3713async fn get_channel_messages_by_id(
3714    _request: proto::GetChannelMessagesById,
3715    _response: Response<proto::GetChannelMessagesById>,
3716    _session: MessageContext,
3717) -> Result<()> {
3718    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3719}
3720
3721/// Retrieve the current users notifications
3722async fn get_notifications(
3723    request: proto::GetNotifications,
3724    response: Response<proto::GetNotifications>,
3725    session: MessageContext,
3726) -> Result<()> {
3727    let notifications = session
3728        .db()
3729        .await
3730        .get_notifications(
3731            session.user_id(),
3732            NOTIFICATION_COUNT_PER_PAGE,
3733            request.before_id.map(db::NotificationId::from_proto),
3734        )
3735        .await?;
3736    response.send(proto::GetNotificationsResponse {
3737        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3738        notifications,
3739    })?;
3740    Ok(())
3741}
3742
3743/// Mark notifications as read
3744async fn mark_notification_as_read(
3745    request: proto::MarkNotificationRead,
3746    response: Response<proto::MarkNotificationRead>,
3747    session: MessageContext,
3748) -> Result<()> {
3749    let database = &session.db().await;
3750    let notifications = database
3751        .mark_notification_as_read_by_id(
3752            session.user_id(),
3753            NotificationId::from_proto(request.notification_id),
3754        )
3755        .await?;
3756    send_notifications(
3757        &*session.connection_pool().await,
3758        &session.peer,
3759        notifications,
3760    );
3761    response.send(proto::Ack {})?;
3762    Ok(())
3763}
3764
3765fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3766    let message = match message {
3767        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3768        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3769        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3770        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3771        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3772            code: frame.code.into(),
3773            reason: frame.reason.as_str().to_owned().into(),
3774        })),
3775        // We should never receive a frame while reading the message, according
3776        // to the `tungstenite` maintainers:
3777        //
3778        // > It cannot occur when you read messages from the WebSocket, but it
3779        // > can be used when you want to send the raw frames (e.g. you want to
3780        // > send the frames to the WebSocket without composing the full message first).
3781        // >
3782        // > — https://github.com/snapview/tungstenite-rs/issues/268
3783        TungsteniteMessage::Frame(_) => {
3784            bail!("received an unexpected frame while reading the message")
3785        }
3786    };
3787
3788    Ok(message)
3789}
3790
3791fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3792    match message {
3793        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3794        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3795        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3796        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3797        AxumMessage::Close(frame) => {
3798            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3799                code: frame.code.into(),
3800                reason: frame.reason.as_ref().into(),
3801            }))
3802        }
3803    }
3804}
3805
3806fn notify_membership_updated(
3807    connection_pool: &mut ConnectionPool,
3808    result: MembershipUpdated,
3809    user_id: UserId,
3810    peer: &Peer,
3811) {
3812    for membership in &result.new_channels.channel_memberships {
3813        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3814    }
3815    for channel_id in &result.removed_channels {
3816        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3817    }
3818
3819    let user_channels_update = proto::UpdateUserChannels {
3820        channel_memberships: result
3821            .new_channels
3822            .channel_memberships
3823            .iter()
3824            .map(|cm| proto::ChannelMembership {
3825                channel_id: cm.channel_id.to_proto(),
3826                role: cm.role.into(),
3827            })
3828            .collect(),
3829        ..Default::default()
3830    };
3831
3832    let mut update = build_channels_update(result.new_channels);
3833    update.delete_channels = result
3834        .removed_channels
3835        .into_iter()
3836        .map(|id| id.to_proto())
3837        .collect();
3838    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3839
3840    for connection_id in connection_pool.user_connection_ids(user_id) {
3841        peer.send(connection_id, user_channels_update.clone())
3842            .trace_err();
3843        peer.send(connection_id, update.clone()).trace_err();
3844    }
3845}
3846
3847fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3848    proto::UpdateUserChannels {
3849        channel_memberships: channels
3850            .channel_memberships
3851            .iter()
3852            .map(|m| proto::ChannelMembership {
3853                channel_id: m.channel_id.to_proto(),
3854                role: m.role.into(),
3855            })
3856            .collect(),
3857        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3858    }
3859}
3860
3861fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3862    let mut update = proto::UpdateChannels::default();
3863
3864    for channel in channels.channels {
3865        update.channels.push(channel.to_proto());
3866    }
3867
3868    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3869
3870    for (channel_id, participants) in channels.channel_participants {
3871        update
3872            .channel_participants
3873            .push(proto::ChannelParticipants {
3874                channel_id: channel_id.to_proto(),
3875                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3876            });
3877    }
3878
3879    for channel in channels.invited_channels {
3880        update.channel_invitations.push(channel.to_proto());
3881    }
3882
3883    update
3884}
3885
3886fn build_initial_contacts_update(
3887    contacts: Vec<db::Contact>,
3888    pool: &ConnectionPool,
3889) -> proto::UpdateContacts {
3890    let mut update = proto::UpdateContacts::default();
3891
3892    for contact in contacts {
3893        match contact {
3894            db::Contact::Accepted { user_id, busy } => {
3895                update.contacts.push(contact_for_user(user_id, busy, pool));
3896            }
3897            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3898            db::Contact::Incoming { user_id } => {
3899                update
3900                    .incoming_requests
3901                    .push(proto::IncomingContactRequest {
3902                        requester_id: user_id.to_proto(),
3903                    })
3904            }
3905        }
3906    }
3907
3908    update
3909}
3910
3911fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3912    proto::Contact {
3913        user_id: user_id.to_proto(),
3914        online: pool.is_user_online(user_id),
3915        busy,
3916    }
3917}
3918
3919fn room_updated(room: &proto::Room, peer: &Peer) {
3920    broadcast(
3921        None,
3922        room.participants
3923            .iter()
3924            .filter_map(|participant| Some(participant.peer_id?.into())),
3925        |peer_id| {
3926            peer.send(
3927                peer_id,
3928                proto::RoomUpdated {
3929                    room: Some(room.clone()),
3930                },
3931            )
3932        },
3933    );
3934}
3935
3936fn channel_updated(
3937    channel: &db::channel::Model,
3938    room: &proto::Room,
3939    peer: &Peer,
3940    pool: &ConnectionPool,
3941) {
3942    let participants = room
3943        .participants
3944        .iter()
3945        .map(|p| p.user_id)
3946        .collect::<Vec<_>>();
3947
3948    broadcast(
3949        None,
3950        pool.channel_connection_ids(channel.root_id())
3951            .filter_map(|(channel_id, role)| {
3952                role.can_see_channel(channel.visibility)
3953                    .then_some(channel_id)
3954            }),
3955        |peer_id| {
3956            peer.send(
3957                peer_id,
3958                proto::UpdateChannels {
3959                    channel_participants: vec![proto::ChannelParticipants {
3960                        channel_id: channel.id.to_proto(),
3961                        participant_user_ids: participants.clone(),
3962                    }],
3963                    ..Default::default()
3964                },
3965            )
3966        },
3967    );
3968}
3969
3970async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3971    let db = session.db().await;
3972
3973    let contacts = db.get_contacts(user_id).await?;
3974    let busy = db.is_user_busy(user_id).await?;
3975
3976    let pool = session.connection_pool().await;
3977    let updated_contact = contact_for_user(user_id, busy, &pool);
3978    for contact in contacts {
3979        if let db::Contact::Accepted {
3980            user_id: contact_user_id,
3981            ..
3982        } = contact
3983        {
3984            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3985                session
3986                    .peer
3987                    .send(
3988                        contact_conn_id,
3989                        proto::UpdateContacts {
3990                            contacts: vec![updated_contact.clone()],
3991                            remove_contacts: Default::default(),
3992                            incoming_requests: Default::default(),
3993                            remove_incoming_requests: Default::default(),
3994                            outgoing_requests: Default::default(),
3995                            remove_outgoing_requests: Default::default(),
3996                        },
3997                    )
3998                    .trace_err();
3999            }
4000        }
4001    }
4002    Ok(())
4003}
4004
/// Remove `connection_id` from whatever room the database says it is in, and
/// notify everyone affected.
///
/// If the database reports no room for the connection, this returns `Ok(())`
/// without doing anything. Otherwise, in order:
/// 1. Collaborators of any projects left are notified via `project_left`.
/// 2. The room state is broadcast to its remaining participants.
/// 3. If the room belongs to a channel, channel participants are updated.
/// 4. Users whose pending calls were canceled receive `CallCanceled`.
/// 5. Contacts of the session user and of the canceled callees are refreshed.
/// 6. With a LiveKit client configured, the session user is removed from the
///    LiveKit room, and the LiveKit room is deleted when the database marked
///    the room as deleted. LiveKit errors are logged only.
///
/// # Errors
/// Returns database errors from leaving the room or refreshing contacts.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned in the `if let` below; the `else` branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // The LiveKit room name is taken out before `room` is moved out, so the
        // `room` broadcast below carries a default (empty) `livekit_room`.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        // Scoped so the connection pool guard is released before the
        // contact updates below.
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4078
4079async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4080    let left_channel_buffers = session
4081        .db()
4082        .await
4083        .leave_channel_buffers(session.connection_id)
4084        .await?;
4085
4086    for left_buffer in left_channel_buffers {
4087        channel_buffer_updated(
4088            session.connection_id,
4089            left_buffer.connections,
4090            &proto::UpdateChannelBufferCollaborators {
4091                channel_id: left_buffer.channel_id.to_proto(),
4092                collaborators: left_buffer.collaborators,
4093            },
4094            &session.peer,
4095        );
4096    }
4097
4098    Ok(())
4099}
4100
4101fn project_left(project: &db::LeftProject, session: &Session) {
4102    for connection_id in &project.connection_ids {
4103        if project.should_unshare {
4104            session
4105                .peer
4106                .send(
4107                    *connection_id,
4108                    proto::UnshareProject {
4109                        project_id: project.id.to_proto(),
4110                    },
4111                )
4112                .trace_err();
4113        } else {
4114            session
4115                .peer
4116                .send(
4117                    *connection_id,
4118                    proto::RemoveProjectCollaborator {
4119                        project_id: project.id.to_proto(),
4120                        peer_id: Some(session.connection_id.into()),
4121                    },
4122                )
4123                .trace_err();
4124        }
4125    }
4126}
4127
4128pub trait ResultExt {
4129    type Ok;
4130
4131    fn trace_err(self) -> Option<Self::Ok>;
4132}
4133
4134impl<T, E> ResultExt for Result<T, E>
4135where
4136    E: std::fmt::Debug,
4137{
4138    type Ok = T;
4139
4140    #[track_caller]
4141    fn trace_err(self) -> Option<T> {
4142        match self {
4143            Ok(value) => Some(value),
4144            Err(error) => {
4145                tracing::error!("{:?}", error);
4146                None
4147            }
4148        }
4149    }
4150}