rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::{
   5    AppState, Error, Result, auth,
   6    db::{
   7        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
   8        InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
   9        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
  10    },
  11    executor::Executor,
  12};
  13use anyhow::{Context as _, anyhow, bail};
  14use async_tungstenite::tungstenite::{
  15    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  16};
  17use axum::headers::UserAgent;
  18use axum::{
  19    Extension, Router, TypedHeader,
  20    body::Body,
  21    extract::{
  22        ConnectInfo, WebSocketUpgrade,
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30};
  31use collections::{HashMap, HashSet};
  32pub use connection_pool::{ConnectionPool, ZedVersion};
  33use core::fmt::{self, Debug, Formatter};
  34use futures::TryFutureExt as _;
  35use reqwest_client::ReqwestClient;
  36use rpc::proto::split_repository_update;
  37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  38use tracing::Span;
  39use util::paths::PathStyle;
  40
  41use futures::{
  42    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  43    stream::FuturesUnordered,
  44};
  45use prometheus::{IntGauge, register_int_gauge};
  46use rpc::{
  47    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  50        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52};
  53use semver::Version;
  54use serde::{Serialize, Serializer};
  55use std::{
  56    any::TypeId,
  57    future::Future,
  58    marker::PhantomData,
  59    mem,
  60    net::SocketAddr,
  61    ops::{Deref, DerefMut},
  62    rc::Rc,
  63    sync::{
  64        Arc, OnceLock,
  65        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  66    },
  67    time::{Duration, Instant},
  68};
  69use tokio::sync::{Semaphore, watch};
  70use tower::ServiceBuilder;
  71use tracing::{
  72    Instrument,
  73    field::{self},
  74    info_span, instrument,
  75};
  76
  77pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  78
  79// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  80pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  81
  82const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  83const MAX_CONCURRENT_CONNECTIONS: usize = 512;
  84
  85static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
  86
  87const TOTAL_DURATION_MS: &str = "total_duration_ms";
  88const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
  89const QUEUE_DURATION_MS: &str = "queue_duration_ms";
  90const HOST_WAITING_MS: &str = "host_waiting_ms";
  91
  92type MessageHandler =
  93    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
  94
  95pub struct ConnectionGuard;
  96
  97impl ConnectionGuard {
  98    pub fn try_acquire() -> Result<Self, ()> {
  99        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
 100        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 101            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 102            tracing::error!(
 103                "too many concurrent connections: {}",
 104                current_connections + 1
 105            );
 106            return Err(());
 107        }
 108        Ok(ConnectionGuard)
 109    }
 110}
 111
 112impl Drop for ConnectionGuard {
 113    fn drop(&mut self) {
 114        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 115    }
 116}
 117
 118struct Response<R> {
 119    peer: Arc<Peer>,
 120    receipt: Receipt<R>,
 121    responded: Arc<AtomicBool>,
 122}
 123
 124impl<R: RequestMessage> Response<R> {
 125    fn send(self, payload: R::Response) -> Result<()> {
 126        self.responded.store(true, SeqCst);
 127        self.peer.respond(self.receipt, payload)?;
 128        Ok(())
 129    }
 130}
 131
 132#[derive(Clone, Debug)]
 133pub enum Principal {
 134    User(User),
 135    Impersonated { user: User, admin: User },
 136}
 137
 138impl Principal {
 139    fn update_span(&self, span: &tracing::Span) {
 140        match &self {
 141            Principal::User(user) => {
 142                span.record("user_id", user.id.0);
 143                span.record("login", &user.github_login);
 144            }
 145            Principal::Impersonated { user, admin } => {
 146                span.record("user_id", user.id.0);
 147                span.record("login", &user.github_login);
 148                span.record("impersonator", &admin.github_login);
 149            }
 150        }
 151    }
 152}
 153
 154#[derive(Clone)]
 155struct MessageContext {
 156    session: Session,
 157    span: tracing::Span,
 158}
 159
 160impl Deref for MessageContext {
 161    type Target = Session;
 162
 163    fn deref(&self) -> &Self::Target {
 164        &self.session
 165    }
 166}
 167
 168impl MessageContext {
 169    pub fn forward_request<T: RequestMessage>(
 170        &self,
 171        receiver_id: ConnectionId,
 172        request: T,
 173    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 174        let request_start_time = Instant::now();
 175        let span = self.span.clone();
 176        tracing::info!("start forwarding request");
 177        self.peer
 178            .forward_request(self.connection_id, receiver_id, request)
 179            .inspect(move |_| {
 180                span.record(
 181                    HOST_WAITING_MS,
 182                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 183                );
 184            })
 185            .inspect_err(|_| tracing::error!("error forwarding request"))
 186            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 187    }
 188}
 189
 190#[derive(Clone)]
 191struct Session {
 192    principal: Principal,
 193    connection_id: ConnectionId,
 194    db: Arc<tokio::sync::Mutex<DbHandle>>,
 195    peer: Arc<Peer>,
 196    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 197    app_state: Arc<AppState>,
 198    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 199    /// The GeoIP country code for the user.
 200    #[allow(unused)]
 201    geoip_country_code: Option<String>,
 202    #[allow(unused)]
 203    system_id: Option<String>,
 204    _executor: Executor,
 205}
 206
 207impl Session {
 208    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 209        #[cfg(test)]
 210        tokio::task::yield_now().await;
 211        let guard = self.db.lock().await;
 212        #[cfg(test)]
 213        tokio::task::yield_now().await;
 214        guard
 215    }
 216
 217    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 218        #[cfg(test)]
 219        tokio::task::yield_now().await;
 220        let guard = self.connection_pool.lock();
 221        ConnectionPoolGuard {
 222            guard,
 223            _not_send: PhantomData,
 224        }
 225    }
 226
 227    fn is_staff(&self) -> bool {
 228        match &self.principal {
 229            Principal::User(user) => user.admin,
 230            Principal::Impersonated { .. } => true,
 231        }
 232    }
 233
 234    fn user_id(&self) -> UserId {
 235        match &self.principal {
 236            Principal::User(user) => user.id,
 237            Principal::Impersonated { user, .. } => user.id,
 238        }
 239    }
 240
 241    pub fn email(&self) -> Option<String> {
 242        match &self.principal {
 243            Principal::User(user) => user.email_address.clone(),
 244            Principal::Impersonated { user, .. } => user.email_address.clone(),
 245        }
 246    }
 247}
 248
 249impl Debug for Session {
 250    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 251        let mut result = f.debug_struct("Session");
 252        match &self.principal {
 253            Principal::User(user) => {
 254                result.field("user", &user.github_login);
 255            }
 256            Principal::Impersonated { user, admin } => {
 257                result.field("user", &user.github_login);
 258                result.field("impersonator", &admin.github_login);
 259            }
 260        }
 261        result.field("connection_id", &self.connection_id).finish()
 262    }
 263}
 264
 265struct DbHandle(Arc<Database>);
 266
 267impl Deref for DbHandle {
 268    type Target = Database;
 269
 270    fn deref(&self) -> &Self::Target {
 271        self.0.as_ref()
 272    }
 273}
 274
/// The collaboration RPC server.
///
/// Owns the [`Peer`], the shared connection pool, and the table of message
/// handlers keyed by message type.
pub struct Server {
    // Behind a mutex so the id can be swapped at runtime (the test-only `reset`
    // does this).
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // Exactly one handler per message type, keyed by the message's `TypeId`.
    handlers: HashMap<TypeId, MessageHandler>,
    // Teardown signal: `teardown()` sends `true`, the test-only `reset()` sends `false`.
    teardown: watch::Sender<bool>,
}

/// Lock guard for the shared [`ConnectionPool`], returned by
/// `Session::connection_pool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc<()>` is `!Send`, which makes this guard `!Send`. That keeps it from
    // being held across an `.await` in a future that must be `Send`.
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable snapshot of the server's peer and its (locked) connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through `serialize_deref`, i.e. as the value the guard derefs
    // to. This relies on a `Deref` impl for `ConnectionPoolGuard`, presumably
    // defined elsewhere in this file — TODO confirm.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 295
 296pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 297where
 298    S: Serializer,
 299    T: Deref<Target = U>,
 300    U: Serialize,
 301{
 302    Serialize::serialize(value.deref(), serializer)
 303}
 304
 305impl Server {
 306    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 307        let mut server = Self {
 308            id: parking_lot::Mutex::new(id),
 309            peer: Peer::new(id.0 as u32),
 310            app_state,
 311            connection_pool: Default::default(),
 312            handlers: Default::default(),
 313            teardown: watch::channel(false).0,
 314        };
 315
 316        server
 317            .add_request_handler(ping)
 318            .add_request_handler(create_room)
 319            .add_request_handler(join_room)
 320            .add_request_handler(rejoin_room)
 321            .add_request_handler(leave_room)
 322            .add_request_handler(set_room_participant_role)
 323            .add_request_handler(call)
 324            .add_request_handler(cancel_call)
 325            .add_message_handler(decline_call)
 326            .add_request_handler(update_participant_location)
 327            .add_request_handler(share_project)
 328            .add_message_handler(unshare_project)
 329            .add_request_handler(join_project)
 330            .add_message_handler(leave_project)
 331            .add_request_handler(update_project)
 332            .add_request_handler(update_worktree)
 333            .add_request_handler(update_repository)
 334            .add_request_handler(remove_repository)
 335            .add_message_handler(start_language_server)
 336            .add_message_handler(update_language_server)
 337            .add_message_handler(update_diagnostic_summary)
 338            .add_message_handler(update_worktree_settings)
 339            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 340            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 341            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 342            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 343            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 344            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 345            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 346            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 347            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 348            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 349            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
 350            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 351            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
 352            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 353            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 354            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 355            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 356            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 357            .add_request_handler(
 358                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 359            )
 360            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 361            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 362            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 363            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 364            .add_request_handler(
 365                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 366            )
 367            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 368            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 369            .add_request_handler(
 370                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 371            )
 372            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 373            .add_request_handler(
 374                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 375            )
 376            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 377            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 378            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 379            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 380            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 381            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 382            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 383            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 384            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 385            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 386            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 387            .add_request_handler(
 388                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 389            )
 390            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 391            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 392            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 393            .add_request_handler(lsp_query)
 394            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
 395            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 396            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 397            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 398            .add_message_handler(create_buffer_for_peer)
 399            .add_message_handler(create_image_for_peer)
 400            .add_request_handler(update_buffer)
 401            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 402            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 403            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 404            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 405            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 406            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 407            .add_message_handler(
 408                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 409            )
 410            .add_request_handler(get_users)
 411            .add_request_handler(fuzzy_search_users)
 412            .add_request_handler(request_contact)
 413            .add_request_handler(remove_contact)
 414            .add_request_handler(respond_to_contact_request)
 415            .add_message_handler(subscribe_to_channels)
 416            .add_request_handler(create_channel)
 417            .add_request_handler(delete_channel)
 418            .add_request_handler(invite_channel_member)
 419            .add_request_handler(remove_channel_member)
 420            .add_request_handler(set_channel_member_role)
 421            .add_request_handler(set_channel_visibility)
 422            .add_request_handler(rename_channel)
 423            .add_request_handler(join_channel_buffer)
 424            .add_request_handler(leave_channel_buffer)
 425            .add_message_handler(update_channel_buffer)
 426            .add_request_handler(rejoin_channel_buffers)
 427            .add_request_handler(get_channel_members)
 428            .add_request_handler(respond_to_channel_invite)
 429            .add_request_handler(join_channel)
 430            .add_request_handler(join_channel_chat)
 431            .add_message_handler(leave_channel_chat)
 432            .add_request_handler(send_channel_message)
 433            .add_request_handler(remove_channel_message)
 434            .add_request_handler(update_channel_message)
 435            .add_request_handler(get_channel_messages)
 436            .add_request_handler(get_channel_messages_by_id)
 437            .add_request_handler(get_notifications)
 438            .add_request_handler(mark_notification_as_read)
 439            .add_request_handler(move_channel)
 440            .add_request_handler(reorder_channel)
 441            .add_request_handler(follow)
 442            .add_message_handler(unfollow)
 443            .add_message_handler(update_followers)
 444            .add_message_handler(acknowledge_channel_message)
 445            .add_message_handler(acknowledge_buffer_version)
 446            .add_request_handler(get_supermaven_api_key)
 447            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 448            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 449            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 450            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 451            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 452            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 453            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 454            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
 455            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 456            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
 457            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 458            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 459            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 460            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 461            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 462            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 463            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 464            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 465            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 466            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 467            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 468            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
 469            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
 470            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 471            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 472            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 473            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 474            .add_message_handler(update_context)
 475            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
 476            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>);
 477
 478        Arc::new(server)
 479    }
 480
    /// Spawns a detached background task that cleans up state left behind by
    /// earlier server instances, then returns `Ok(())` immediately.
    ///
    /// The task first waits for [`CLEANUP_TIMEOUT`]. Every database failure along
    /// the way is logged via `trace_err` and otherwise ignored. The task then:
    /// 1. deletes stale channel chat participants;
    /// 2. fetches the ids of stale rooms and channel buffers. For each channel
    ///    buffer it clears stale collaborators and sends the refreshed
    ///    collaborator list to the buffer's remaining connections;
    /// 3. for each stale room, clears stale participants and broadcasts the room
    ///    (and channel) update. It also notifies users whose calls were canceled,
    ///    pushes updated contact info to the accepted contacts of affected users,
    ///    and deletes the LiveKit room if no participants remain;
    /// 4. deletes stale channel chat participants again, clears old worktree
    ///    entries, and deletes stale servers.
    ///
    /// NOTE(review): "stale" is decided by the `db` methods. It presumably means
    /// rows owned by server ids other than `server_id` in the same
    /// `zed_environment` — confirm against the database layer.
    pub async fn start(&self) -> Result<()> {
        // Copy everything the detached task needs, so it does not borrow `self`.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is awaited inside the spawned task below.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and tell the remaining
                    // connections about the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used when clearing the room's participants fails:
                        // nothing to notify, and the LiveKit room is left alone.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                // The pool lock is a temporary, released at the end of
                                // this statement.
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and users whose calls were
                            // canceled need their contacts' view refreshed.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell every connection of each user whose call was canceled.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Both DB reads finish before the pool is locked, so no
                            // `.await` happens while the lock is held.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                // Only accepted contacts are notified.
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room when a client is configured and
                        // the refreshed room has no participants left.
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                // NOTE(review): this repeats the call made at the start of this task.
                // Confirm whether the second pass is intentional (e.g. to catch rows
                // that became stale during room cleanup) or can be removed.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                // Runs last, after every other cleanup step above.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 651
 652    pub fn teardown(&self) {
 653        self.peer.teardown();
 654        self.connection_pool.lock().reset();
 655        let _ = self.teardown.send(true);
 656    }
 657
 658    #[cfg(test)]
 659    pub fn reset(&self, id: ServerId) {
 660        self.teardown();
 661        *self.id.lock() = id;
 662        self.peer.reset(id.0 as u32);
 663        let _ = self.teardown.send(false);
 664    }
 665
 666    #[cfg(test)]
 667    pub fn id(&self) -> ServerId {
 668        *self.id.lock()
 669    }
 670
    /// Registers a type-erased handler for incoming messages of type `M`.
    ///
    /// The handler is stored in `self.handlers` under `TypeId::of::<M>()`.
    /// The stored closure downcasts the envelope back to `TypedEnvelope<M>` and
    /// invokes `handler` with a `MessageContext`. It returns a boxed future
    /// that, once the handler future completes, records total, processing and
    /// queue durations on `span` and logs the outcome. Handler errors are only
    /// logged here; they are not propagated any further.
    ///
    /// # Panics
    ///
    /// Panics if a handler for `M` has already been registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // Handlers are looked up by the payload's type id, which matches
                // the `TypeId::of::<M>()` key used above, so this downcast is not
                // expected to fail.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                // Taken before the handler is invoked, so "processing" covers
                // both building and awaiting the handler future.
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    // Time spent waiting between receipt and handler start.
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
 714
 715    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 716    where
 717        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 718        Fut: 'static + Send + Future<Output = Result<()>>,
 719        M: EnvelopedMessage,
 720    {
 721        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 722        self
 723    }
 724
    /// Registers a handler for request messages of type `M` that must reply.
    ///
    /// The handler receives the payload, a [`Response`] bound to the request's
    /// receipt, and the `MessageContext`. After the handler finishes:
    /// - `Ok(())` is accepted only if the shared `responded` flag was set
    ///   (presumably by `Response::send`, defined outside this view); otherwise
    ///   "handler did not send a response" is returned as an error.
    /// - `Err(error)` causes an error reply to the requester. `Error::Internal`
    ///   values use their own proto conversion; every other error is sent as
    ///   `ErrorCode::Internal` with the error's display text. The original
    ///   error is then returned, so the `add_handler` wrapper logs it. If
    ///   sending the error reply itself fails, that failure is returned instead.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        // Shared so that each per-message `async move` block can own a handle.
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // Shared with the `Response` so we can tell afterwards whether
                // the handler replied.
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
 763
    /// Builds the future that serves one client connection until it ends.
    ///
    /// Done eagerly, before the returned future is polled:
    /// - a `handle connection` span is created and filled from `principal`,
    ///   `user_agent`, `release_channel` and `geoip_country_code`;
    /// - a receiver for the server's teardown watch channel is subscribed.
    ///
    /// When polled, the future:
    /// 1. Logs and returns immediately if the server is already tearing down.
    /// 2. Registers `connection` with the peer and builds an HTTP client; if
    ///    the client cannot be created, it logs and returns.
    /// 3. Builds the per-connection [`Session`], including a Supermaven admin
    ///    client when an admin API key is configured.
    /// 4. Sends the initial client update. On failure it logs and returns.
    ///    `send_connection_id`, if given, receives the connection id right
    ///    after `Hello` is sent.
    /// 5. Drops `connection_guard`.
    /// 6. Dispatches incoming messages to registered handlers until the I/O
    ///    future completes, the incoming stream ends, or teardown fires.
    /// 7. Runs [`connection_lost`], unless it exited because of teardown.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribed before the future starts so a teardown that happens in
        // between is still observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            // The peer is given the executor's sleep function; what it uses
            // it for is not visible here.
            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            // `Span::current()` is the `handle connection` span, because this
            // future is wrapped with `.instrument(span)` below.
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The guard is held only through the initial handshake above. It
            // presumably limits concurrent connection setup (see the
            // `ConnectionGuard::try_acquire` call in `handle_websocket_request`).
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            // Number of permits currently held, i.e. handlers in flight.
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // A permit is acquired BEFORE reading the next message. Reading
                // therefore stalls once MAX_CONCURRENT_HANDLERS handlers
                // (foreground or background) hold permits.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown first, then I/O completion, then finishing
                // foreground handlers, and only then reading new messages.
                futures::select_biased! {
                    // Teardown: exit without running `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            // Entered only while the handler future is constructed.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // The permit is released only when the handler finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight;
            // detached background handlers are not affected.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
 955
    /// Performs the handshake work for a newly registered connection.
    ///
    /// It always sends `proto::Hello` carrying the connection's peer id. It
    /// then passes `connection_id` to `send_connection_id`, if one was
    /// supplied; a dropped receiver is ignored.
    ///
    /// For user principals (plain or impersonated) it additionally:
    /// - sends `ShowContacts` and sets `connected_once` the first time the
    ///   user ever connects;
    /// - adds the connection to the pool and sends the initial contacts update,
    ///   both while holding a single pool lock;
    /// - subscribes the user to channels when `should_auto_subscribe_to_channels`
    ///   allows it for this client version;
    /// - forwards any pending incoming call;
    /// - refreshes contacts via `update_user_contacts`.
    ///
    /// # Errors
    ///
    /// Any send or database error aborts the remaining steps and is returned.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                // Register the connection and build the contacts update under
                // one lock, so the update reflects this connection being present.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(&zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
1011
1012    pub async fn invite_code_redeemed(
1013        self: &Arc<Self>,
1014        inviter_id: UserId,
1015        invitee_id: UserId,
1016    ) -> Result<()> {
1017        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await?
1018            && let Some(code) = &user.invite_code
1019        {
1020            let pool = self.connection_pool.lock();
1021            let invitee_contact = contact_for_user(invitee_id, false, &pool);
1022            for connection_id in pool.user_connection_ids(inviter_id) {
1023                self.peer.send(
1024                    connection_id,
1025                    proto::UpdateContacts {
1026                        contacts: vec![invitee_contact.clone()],
1027                        ..Default::default()
1028                    },
1029                )?;
1030                self.peer.send(
1031                    connection_id,
1032                    proto::UpdateInviteInfo {
1033                        url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1034                        count: user.invite_count as u32,
1035                    },
1036                )?;
1037            }
1038        }
1039        Ok(())
1040    }
1041
1042    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1043        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await?
1044            && let Some(invite_code) = &user.invite_code
1045        {
1046            let pool = self.connection_pool.lock();
1047            for connection_id in pool.user_connection_ids(user_id) {
1048                self.peer.send(
1049                    connection_id,
1050                    proto::UpdateInviteInfo {
1051                        url: format!(
1052                            "{}{}",
1053                            self.app_state.config.invite_link_prefix, invite_code
1054                        ),
1055                        count: user.invite_count as u32,
1056                    },
1057                )?;
1058            }
1059        }
1060        Ok(())
1061    }
1062
1063    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1064        ServerSnapshot {
1065            connection_pool: ConnectionPoolGuard {
1066                guard: self.connection_pool.lock(),
1067                _not_send: PhantomData,
1068            },
1069            peer: &self.peer,
1070        }
1071    }
1072}
1073
1074impl Deref for ConnectionPoolGuard<'_> {
1075    type Target = ConnectionPool;
1076
1077    fn deref(&self) -> &Self::Target {
1078        &self.guard
1079    }
1080}
1081
1082impl DerefMut for ConnectionPoolGuard<'_> {
1083    fn deref_mut(&mut self) -> &mut Self::Target {
1084        &mut self.guard
1085    }
1086}
1087
1088impl Drop for ConnectionPoolGuard<'_> {
1089    fn drop(&mut self) {
1090        #[cfg(test)]
1091        self.check_invariants();
1092    }
1093}
1094
1095fn broadcast<F>(
1096    sender_id: Option<ConnectionId>,
1097    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1098    mut f: F,
1099) where
1100    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1101{
1102    for receiver_id in receiver_ids {
1103        if Some(receiver_id) != sender_id
1104            && let Err(error) = f(receiver_id)
1105        {
1106            tracing::error!("failed to send to {:?} {}", receiver_id, error);
1107        }
1108    }
1109}
1110
1111pub struct ProtocolVersion(u32);
1112
1113impl Header for ProtocolVersion {
1114    fn name() -> &'static HeaderName {
1115        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1116        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1117    }
1118
1119    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1120    where
1121        Self: Sized,
1122        I: Iterator<Item = &'i axum::http::HeaderValue>,
1123    {
1124        let version = values
1125            .next()
1126            .ok_or_else(axum::headers::Error::invalid)?
1127            .to_str()
1128            .map_err(|_| axum::headers::Error::invalid())?
1129            .parse()
1130            .map_err(|_| axum::headers::Error::invalid())?;
1131        Ok(Self(version))
1132    }
1133
1134    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1135        values.extend([self.0.to_string().parse().unwrap()]);
1136    }
1137}
1138
1139pub struct AppVersionHeader(Version);
1140impl Header for AppVersionHeader {
1141    fn name() -> &'static HeaderName {
1142        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1143        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1144    }
1145
1146    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1147    where
1148        Self: Sized,
1149        I: Iterator<Item = &'i axum::http::HeaderValue>,
1150    {
1151        let version = values
1152            .next()
1153            .ok_or_else(axum::headers::Error::invalid)?
1154            .to_str()
1155            .map_err(|_| axum::headers::Error::invalid())?
1156            .parse()
1157            .map_err(|_| axum::headers::Error::invalid())?;
1158        Ok(Self(version))
1159    }
1160
1161    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1162        values.extend([self.0.to_string().parse().unwrap()]);
1163    }
1164}
1165
1166#[derive(Debug)]
1167pub struct ReleaseChannelHeader(String);
1168
1169impl Header for ReleaseChannelHeader {
1170    fn name() -> &'static HeaderName {
1171        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1172        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1173    }
1174
1175    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1176    where
1177        Self: Sized,
1178        I: Iterator<Item = &'i axum::http::HeaderValue>,
1179    {
1180        Ok(Self(
1181            values
1182                .next()
1183                .ok_or_else(axum::headers::Error::invalid)?
1184                .to_str()
1185                .map_err(|_| axum::headers::Error::invalid())?
1186                .to_owned(),
1187        ))
1188    }
1189
1190    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1191        values.extend([self.0.parse().unwrap()]);
1192    }
1193}
1194
1195pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1196    Router::new()
1197        .route("/rpc", get(handle_websocket_request))
1198        .layer(
1199            ServiceBuilder::new()
1200                .layer(Extension(server.app_state.clone()))
1201                .layer(middleware::from_fn(auth::validate_header)),
1202        )
1203        .route("/metrics", get(handle_metrics))
1204        .layer(Extension(server))
1205}
1206
1207pub async fn handle_websocket_request(
1208    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1209    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1210    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1211    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1212    Extension(server): Extension<Arc<Server>>,
1213    Extension(principal): Extension<Principal>,
1214    user_agent: Option<TypedHeader<UserAgent>>,
1215    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1216    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1217    ws: WebSocketUpgrade,
1218) -> axum::response::Response {
1219    if protocol_version != rpc::PROTOCOL_VERSION {
1220        return (
1221            StatusCode::UPGRADE_REQUIRED,
1222            "client must be upgraded".to_string(),
1223        )
1224            .into_response();
1225    }
1226
1227    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1228        return (
1229            StatusCode::UPGRADE_REQUIRED,
1230            "no version header found".to_string(),
1231        )
1232            .into_response();
1233    };
1234
1235    let release_channel = release_channel_header.map(|header| header.0.0);
1236
1237    if !version.can_collaborate() {
1238        return (
1239            StatusCode::UPGRADE_REQUIRED,
1240            "client must be upgraded".to_string(),
1241        )
1242            .into_response();
1243    }
1244
1245    let socket_address = socket_address.to_string();
1246
1247    // Acquire connection guard before WebSocket upgrade
1248    let connection_guard = match ConnectionGuard::try_acquire() {
1249        Ok(guard) => guard,
1250        Err(()) => {
1251            return (
1252                StatusCode::SERVICE_UNAVAILABLE,
1253                "Too many concurrent connections",
1254            )
1255                .into_response();
1256        }
1257    };
1258
1259    ws.on_upgrade(move |socket| {
1260        let socket = socket
1261            .map_ok(to_tungstenite_message)
1262            .err_into()
1263            .with(|message| async move { to_axum_message(message) });
1264        let connection = Connection::new(Box::pin(socket));
1265        async move {
1266            server
1267                .handle_connection(
1268                    connection,
1269                    socket_address,
1270                    principal,
1271                    version,
1272                    release_channel,
1273                    user_agent.map(|header| header.to_string()),
1274                    country_code_header.map(|header| header.to_string()),
1275                    system_id_header.map(|header| header.to_string()),
1276                    None,
1277                    Executor::Production,
1278                    Some(connection_guard),
1279                )
1280                .await;
1281        }
1282    })
1283}
1284
1285pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1286    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1287    let connections_metric = CONNECTIONS_METRIC
1288        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1289
1290    let connections = server
1291        .connection_pool
1292        .lock()
1293        .connections()
1294        .filter(|connection| !connection.admin)
1295        .count();
1296    connections_metric.set(connections as _);
1297
1298    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1299    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1300        register_int_gauge!(
1301            "shared_projects",
1302            "number of open projects with one or more guests"
1303        )
1304        .unwrap()
1305    });
1306
1307    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1308    shared_projects_metric.set(shared_projects as _);
1309
1310    let encoder = prometheus::TextEncoder::new();
1311    let metric_families = prometheus::gather();
1312    let encoded_metrics = encoder
1313        .encode_to_string(&metric_families)
1314        .map_err(|err| anyhow!("{err}"))?;
1315    Ok(encoded_metrics)
1316}
1317
/// Cleans up after a client connection has ended.
///
/// Immediately:
/// - disconnects the peer;
/// - removes the connection from the pool (an error here is returned at once);
/// - marks the connection lost in the database (errors are only logged).
///
/// It then waits for `RECONNECT_TIMEOUT` (defined outside this view). When the
/// wait elapses, it:
/// - leaves the session's room and channel buffers;
/// - declines any pending call if the user has no other online connection;
/// - refreshes the user's contacts.
///
/// If the teardown signal changes first, that cleanup is skipped. Because the
/// select is biased, the timeout arm wins when both are ready.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            // NOTE(review): this is the only `log::` call in view; the rest of
            // the file logs via `tracing`.
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Decline a pending call only if no other connection of this
            // user is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1364
1365/// Acknowledges a ping from a client, used to keep the connection alive.
1366async fn ping(
1367    _: proto::Ping,
1368    response: Response<proto::Ping>,
1369    _session: MessageContext,
1370) -> Result<()> {
1371    response.send(proto::Ack {})?;
1372    Ok(())
1373}
1374
1375/// Creates a new room for calling (outside of channels)
1376async fn create_room(
1377    _request: proto::CreateRoom,
1378    response: Response<proto::CreateRoom>,
1379    session: MessageContext,
1380) -> Result<()> {
1381    let livekit_room = nanoid::nanoid!(30);
1382
1383    let live_kit_connection_info = util::maybe!(async {
1384        let live_kit = session.app_state.livekit_client.as_ref();
1385        let live_kit = live_kit?;
1386        let user_id = session.user_id().to_string();
1387
1388        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1389
1390        Some(proto::LiveKitConnectionInfo {
1391            server_url: live_kit.url().into(),
1392            token,
1393            can_publish: true,
1394        })
1395    })
1396    .await;
1397
1398    let room = session
1399        .db()
1400        .await
1401        .create_room(session.user_id(), session.connection_id, &livekit_room)
1402        .await?;
1403
1404    response.send(proto::CreateRoomResponse {
1405        room: Some(room.clone()),
1406        live_kit_connection_info,
1407    })?;
1408
1409    update_user_contacts(session.user_id(), &session).await?;
1410    Ok(())
1411}
1412
1413/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1414async fn join_room(
1415    request: proto::JoinRoom,
1416    response: Response<proto::JoinRoom>,
1417    session: MessageContext,
1418) -> Result<()> {
1419    let room_id = RoomId::from_proto(request.id);
1420
1421    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1422
1423    if let Some(channel_id) = channel_id {
1424        return join_channel_internal(channel_id, Box::new(response), session).await;
1425    }
1426
1427    let joined_room = {
1428        let room = session
1429            .db()
1430            .await
1431            .join_room(room_id, session.user_id(), session.connection_id)
1432            .await?;
1433        room_updated(&room.room, &session.peer);
1434        room.into_inner()
1435    };
1436
1437    for connection_id in session
1438        .connection_pool()
1439        .await
1440        .user_connection_ids(session.user_id())
1441    {
1442        session
1443            .peer
1444            .send(
1445                connection_id,
1446                proto::CallCanceled {
1447                    room_id: room_id.to_proto(),
1448                },
1449            )
1450            .trace_err();
1451    }
1452
1453    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1454    {
1455        live_kit
1456            .room_token(
1457                &joined_room.room.livekit_room,
1458                &session.user_id().to_string(),
1459            )
1460            .trace_err()
1461            .map(|token| proto::LiveKitConnectionInfo {
1462                server_url: live_kit.url().into(),
1463                token,
1464                can_publish: true,
1465            })
1466    } else {
1467        None
1468    };
1469
1470    response.send(proto::JoinRoomResponse {
1471        room: Some(joined_room.room),
1472        channel_id: None,
1473        live_kit_connection_info,
1474    })?;
1475
1476    update_user_contacts(session.user_id(), &session).await?;
1477    Ok(())
1478}
1479
1480/// Rejoin room is used to reconnect to a room after connection errors.
1481async fn rejoin_room(
1482    request: proto::RejoinRoom,
1483    response: Response<proto::RejoinRoom>,
1484    session: MessageContext,
1485) -> Result<()> {
1486    let room;
1487    let channel;
1488    {
1489        let mut rejoined_room = session
1490            .db()
1491            .await
1492            .rejoin_room(request, session.user_id(), session.connection_id)
1493            .await?;
1494
1495        response.send(proto::RejoinRoomResponse {
1496            room: Some(rejoined_room.room.clone()),
1497            reshared_projects: rejoined_room
1498                .reshared_projects
1499                .iter()
1500                .map(|project| proto::ResharedProject {
1501                    id: project.id.to_proto(),
1502                    collaborators: project
1503                        .collaborators
1504                        .iter()
1505                        .map(|collaborator| collaborator.to_proto())
1506                        .collect(),
1507                })
1508                .collect(),
1509            rejoined_projects: rejoined_room
1510                .rejoined_projects
1511                .iter()
1512                .map(|rejoined_project| rejoined_project.to_proto())
1513                .collect(),
1514        })?;
1515        room_updated(&rejoined_room.room, &session.peer);
1516
1517        for project in &rejoined_room.reshared_projects {
1518            for collaborator in &project.collaborators {
1519                session
1520                    .peer
1521                    .send(
1522                        collaborator.connection_id,
1523                        proto::UpdateProjectCollaborator {
1524                            project_id: project.id.to_proto(),
1525                            old_peer_id: Some(project.old_connection_id.into()),
1526                            new_peer_id: Some(session.connection_id.into()),
1527                        },
1528                    )
1529                    .trace_err();
1530            }
1531
1532            broadcast(
1533                Some(session.connection_id),
1534                project
1535                    .collaborators
1536                    .iter()
1537                    .map(|collaborator| collaborator.connection_id),
1538                |connection_id| {
1539                    session.peer.forward_send(
1540                        session.connection_id,
1541                        connection_id,
1542                        proto::UpdateProject {
1543                            project_id: project.id.to_proto(),
1544                            worktrees: project.worktrees.clone(),
1545                        },
1546                    )
1547                },
1548            );
1549        }
1550
1551        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1552
1553        let rejoined_room = rejoined_room.into_inner();
1554
1555        room = rejoined_room.room;
1556        channel = rejoined_room.channel;
1557    }
1558
1559    if let Some(channel) = channel {
1560        channel_updated(
1561            &channel,
1562            &room,
1563            &session.peer,
1564            &*session.connection_pool().await,
1565        );
1566    }
1567
1568    update_user_contacts(session.user_id(), &session).await?;
1569    Ok(())
1570}
1571
1572fn notify_rejoined_projects(
1573    rejoined_projects: &mut Vec<RejoinedProject>,
1574    session: &Session,
1575) -> Result<()> {
1576    for project in rejoined_projects.iter() {
1577        for collaborator in &project.collaborators {
1578            session
1579                .peer
1580                .send(
1581                    collaborator.connection_id,
1582                    proto::UpdateProjectCollaborator {
1583                        project_id: project.id.to_proto(),
1584                        old_peer_id: Some(project.old_connection_id.into()),
1585                        new_peer_id: Some(session.connection_id.into()),
1586                    },
1587                )
1588                .trace_err();
1589        }
1590    }
1591
1592    for project in rejoined_projects {
1593        for worktree in mem::take(&mut project.worktrees) {
1594            // Stream this worktree's entries.
1595            let message = proto::UpdateWorktree {
1596                project_id: project.id.to_proto(),
1597                worktree_id: worktree.id,
1598                abs_path: worktree.abs_path.clone(),
1599                root_name: worktree.root_name,
1600                updated_entries: worktree.updated_entries,
1601                removed_entries: worktree.removed_entries,
1602                scan_id: worktree.scan_id,
1603                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1604                updated_repositories: worktree.updated_repositories,
1605                removed_repositories: worktree.removed_repositories,
1606            };
1607            for update in proto::split_worktree_update(message) {
1608                session.peer.send(session.connection_id, update)?;
1609            }
1610
1611            // Stream this worktree's diagnostics.
1612            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1613            if let Some(summary) = worktree_diagnostics.next() {
1614                let message = proto::UpdateDiagnosticSummary {
1615                    project_id: project.id.to_proto(),
1616                    worktree_id: worktree.id,
1617                    summary: Some(summary),
1618                    more_summaries: worktree_diagnostics.collect(),
1619                };
1620                session.peer.send(session.connection_id, message)?;
1621            }
1622
1623            for settings_file in worktree.settings_files {
1624                session.peer.send(
1625                    session.connection_id,
1626                    proto::UpdateWorktreeSettings {
1627                        project_id: project.id.to_proto(),
1628                        worktree_id: worktree.id,
1629                        path: settings_file.path,
1630                        content: Some(settings_file.content),
1631                        kind: Some(settings_file.kind.to_proto().into()),
1632                    },
1633                )?;
1634            }
1635        }
1636
1637        for repository in mem::take(&mut project.updated_repositories) {
1638            for update in split_repository_update(repository) {
1639                session.peer.send(session.connection_id, update)?;
1640            }
1641        }
1642
1643        for id in mem::take(&mut project.removed_repositories) {
1644            session.peer.send(
1645                session.connection_id,
1646                proto::RemoveRepository {
1647                    project_id: project.id.to_proto(),
1648                    id,
1649                },
1650            )?;
1651        }
1652    }
1653
1654    Ok(())
1655}
1656
1657/// leave room disconnects from the room.
1658async fn leave_room(
1659    _: proto::LeaveRoom,
1660    response: Response<proto::LeaveRoom>,
1661    session: MessageContext,
1662) -> Result<()> {
1663    leave_room_for_session(&session, session.connection_id).await?;
1664    response.send(proto::Ack {})?;
1665    Ok(())
1666}
1667
1668/// Updates the permissions of someone else in the room.
1669async fn set_room_participant_role(
1670    request: proto::SetRoomParticipantRole,
1671    response: Response<proto::SetRoomParticipantRole>,
1672    session: MessageContext,
1673) -> Result<()> {
1674    let user_id = UserId::from_proto(request.user_id);
1675    let role = ChannelRole::from(request.role());
1676
1677    let (livekit_room, can_publish) = {
1678        let room = session
1679            .db()
1680            .await
1681            .set_room_participant_role(
1682                session.user_id(),
1683                RoomId::from_proto(request.room_id),
1684                user_id,
1685                role,
1686            )
1687            .await?;
1688
1689        let livekit_room = room.livekit_room.clone();
1690        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1691        room_updated(&room, &session.peer);
1692        (livekit_room, can_publish)
1693    };
1694
1695    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1696        live_kit
1697            .update_participant(
1698                livekit_room.clone(),
1699                request.user_id.to_string(),
1700                livekit_api::proto::ParticipantPermission {
1701                    can_subscribe: true,
1702                    can_publish,
1703                    can_publish_data: can_publish,
1704                    hidden: false,
1705                    recorder: false,
1706                },
1707            )
1708            .await
1709            .trace_err();
1710    }
1711
1712    response.send(proto::Ack {})?;
1713    Ok(())
1714}
1715
1716/// Call someone else into the current room
1717async fn call(
1718    request: proto::Call,
1719    response: Response<proto::Call>,
1720    session: MessageContext,
1721) -> Result<()> {
1722    let room_id = RoomId::from_proto(request.room_id);
1723    let calling_user_id = session.user_id();
1724    let calling_connection_id = session.connection_id;
1725    let called_user_id = UserId::from_proto(request.called_user_id);
1726    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1727    if !session
1728        .db()
1729        .await
1730        .has_contact(calling_user_id, called_user_id)
1731        .await?
1732    {
1733        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1734    }
1735
1736    let incoming_call = {
1737        let (room, incoming_call) = &mut *session
1738            .db()
1739            .await
1740            .call(
1741                room_id,
1742                calling_user_id,
1743                calling_connection_id,
1744                called_user_id,
1745                initial_project_id,
1746            )
1747            .await?;
1748        room_updated(room, &session.peer);
1749        mem::take(incoming_call)
1750    };
1751    update_user_contacts(called_user_id, &session).await?;
1752
1753    let mut calls = session
1754        .connection_pool()
1755        .await
1756        .user_connection_ids(called_user_id)
1757        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1758        .collect::<FuturesUnordered<_>>();
1759
1760    while let Some(call_response) = calls.next().await {
1761        match call_response.as_ref() {
1762            Ok(_) => {
1763                response.send(proto::Ack {})?;
1764                return Ok(());
1765            }
1766            Err(_) => {
1767                call_response.trace_err();
1768            }
1769        }
1770    }
1771
1772    {
1773        let room = session
1774            .db()
1775            .await
1776            .call_failed(room_id, called_user_id)
1777            .await?;
1778        room_updated(&room, &session.peer);
1779    }
1780    update_user_contacts(called_user_id, &session).await?;
1781
1782    Err(anyhow!("failed to ring user"))?
1783}
1784
1785/// Cancel an outgoing call.
1786async fn cancel_call(
1787    request: proto::CancelCall,
1788    response: Response<proto::CancelCall>,
1789    session: MessageContext,
1790) -> Result<()> {
1791    let called_user_id = UserId::from_proto(request.called_user_id);
1792    let room_id = RoomId::from_proto(request.room_id);
1793    {
1794        let room = session
1795            .db()
1796            .await
1797            .cancel_call(room_id, session.connection_id, called_user_id)
1798            .await?;
1799        room_updated(&room, &session.peer);
1800    }
1801
1802    for connection_id in session
1803        .connection_pool()
1804        .await
1805        .user_connection_ids(called_user_id)
1806    {
1807        session
1808            .peer
1809            .send(
1810                connection_id,
1811                proto::CallCanceled {
1812                    room_id: room_id.to_proto(),
1813                },
1814            )
1815            .trace_err();
1816    }
1817    response.send(proto::Ack {})?;
1818
1819    update_user_contacts(called_user_id, &session).await?;
1820    Ok(())
1821}
1822
1823/// Decline an incoming call.
1824async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1825    let room_id = RoomId::from_proto(message.room_id);
1826    {
1827        let room = session
1828            .db()
1829            .await
1830            .decline_call(Some(room_id), session.user_id())
1831            .await?
1832            .context("declining call")?;
1833        room_updated(&room, &session.peer);
1834    }
1835
1836    for connection_id in session
1837        .connection_pool()
1838        .await
1839        .user_connection_ids(session.user_id())
1840    {
1841        session
1842            .peer
1843            .send(
1844                connection_id,
1845                proto::CallCanceled {
1846                    room_id: room_id.to_proto(),
1847                },
1848            )
1849            .trace_err();
1850    }
1851    update_user_contacts(session.user_id(), &session).await?;
1852    Ok(())
1853}
1854
1855/// Updates other participants in the room with your current location.
1856async fn update_participant_location(
1857    request: proto::UpdateParticipantLocation,
1858    response: Response<proto::UpdateParticipantLocation>,
1859    session: MessageContext,
1860) -> Result<()> {
1861    let room_id = RoomId::from_proto(request.room_id);
1862    let location = request.location.context("invalid location")?;
1863
1864    let db = session.db().await;
1865    let room = db
1866        .update_room_participant_location(room_id, session.connection_id, location)
1867        .await?;
1868
1869    room_updated(&room, &session.peer);
1870    response.send(proto::Ack {})?;
1871    Ok(())
1872}
1873
1874/// Share a project into the room.
1875async fn share_project(
1876    request: proto::ShareProject,
1877    response: Response<proto::ShareProject>,
1878    session: MessageContext,
1879) -> Result<()> {
1880    let (project_id, room) = &*session
1881        .db()
1882        .await
1883        .share_project(
1884            RoomId::from_proto(request.room_id),
1885            session.connection_id,
1886            &request.worktrees,
1887            request.is_ssh_project,
1888            request.windows_paths.unwrap_or(false),
1889        )
1890        .await?;
1891    response.send(proto::ShareProjectResponse {
1892        project_id: project_id.to_proto(),
1893    })?;
1894    room_updated(room, &session.peer);
1895
1896    Ok(())
1897}
1898
1899/// Unshare a project from the room.
1900async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1901    let project_id = ProjectId::from_proto(message.project_id);
1902    unshare_project_internal(project_id, session.connection_id, &session).await
1903}
1904
1905async fn unshare_project_internal(
1906    project_id: ProjectId,
1907    connection_id: ConnectionId,
1908    session: &Session,
1909) -> Result<()> {
1910    let delete = {
1911        let room_guard = session
1912            .db()
1913            .await
1914            .unshare_project(project_id, connection_id)
1915            .await?;
1916
1917        let (delete, room, guest_connection_ids) = &*room_guard;
1918
1919        let message = proto::UnshareProject {
1920            project_id: project_id.to_proto(),
1921        };
1922
1923        broadcast(
1924            Some(connection_id),
1925            guest_connection_ids.iter().copied(),
1926            |conn_id| session.peer.send(conn_id, message.clone()),
1927        );
1928        if let Some(room) = room {
1929            room_updated(room, &session.peer);
1930        }
1931
1932        *delete
1933    };
1934
1935    if delete {
1936        let db = session.db().await;
1937        db.delete_project(project_id).await?;
1938    }
1939
1940    Ok(())
1941}
1942
1943/// Join someone elses shared project.
1944async fn join_project(
1945    request: proto::JoinProject,
1946    response: Response<proto::JoinProject>,
1947    session: MessageContext,
1948) -> Result<()> {
1949    let project_id = ProjectId::from_proto(request.project_id);
1950
1951    tracing::info!(%project_id, "join project");
1952
1953    let db = session.db().await;
1954    let (project, replica_id) = &mut *db
1955        .join_project(
1956            project_id,
1957            session.connection_id,
1958            session.user_id(),
1959            request.committer_name.clone(),
1960            request.committer_email.clone(),
1961        )
1962        .await?;
1963    drop(db);
1964    tracing::info!(%project_id, "join remote project");
1965    let collaborators = project
1966        .collaborators
1967        .iter()
1968        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1969        .map(|collaborator| collaborator.to_proto())
1970        .collect::<Vec<_>>();
1971    let project_id = project.id;
1972    let guest_user_id = session.user_id();
1973
1974    let worktrees = project
1975        .worktrees
1976        .iter()
1977        .map(|(id, worktree)| proto::WorktreeMetadata {
1978            id: *id,
1979            root_name: worktree.root_name.clone(),
1980            visible: worktree.visible,
1981            abs_path: worktree.abs_path.clone(),
1982        })
1983        .collect::<Vec<_>>();
1984
1985    let add_project_collaborator = proto::AddProjectCollaborator {
1986        project_id: project_id.to_proto(),
1987        collaborator: Some(proto::Collaborator {
1988            peer_id: Some(session.connection_id.into()),
1989            replica_id: replica_id.0 as u32,
1990            user_id: guest_user_id.to_proto(),
1991            is_host: false,
1992            committer_name: request.committer_name.clone(),
1993            committer_email: request.committer_email.clone(),
1994        }),
1995    };
1996
1997    for collaborator in &collaborators {
1998        session
1999            .peer
2000            .send(
2001                collaborator.peer_id.unwrap().into(),
2002                add_project_collaborator.clone(),
2003            )
2004            .trace_err();
2005    }
2006
2007    // First, we send the metadata associated with each worktree.
2008    let (language_servers, language_server_capabilities) = project
2009        .language_servers
2010        .clone()
2011        .into_iter()
2012        .map(|server| (server.server, server.capabilities))
2013        .unzip();
2014    response.send(proto::JoinProjectResponse {
2015        project_id: project.id.0 as u64,
2016        worktrees,
2017        replica_id: replica_id.0 as u32,
2018        collaborators,
2019        language_servers,
2020        language_server_capabilities,
2021        role: project.role.into(),
2022        windows_paths: project.path_style == PathStyle::Windows,
2023    })?;
2024
2025    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2026        // Stream this worktree's entries.
2027        let message = proto::UpdateWorktree {
2028            project_id: project_id.to_proto(),
2029            worktree_id,
2030            abs_path: worktree.abs_path.clone(),
2031            root_name: worktree.root_name,
2032            updated_entries: worktree.entries,
2033            removed_entries: Default::default(),
2034            scan_id: worktree.scan_id,
2035            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2036            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2037            removed_repositories: Default::default(),
2038        };
2039        for update in proto::split_worktree_update(message) {
2040            session.peer.send(session.connection_id, update.clone())?;
2041        }
2042
2043        // Stream this worktree's diagnostics.
2044        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
2045        if let Some(summary) = worktree_diagnostics.next() {
2046            let message = proto::UpdateDiagnosticSummary {
2047                project_id: project.id.to_proto(),
2048                worktree_id: worktree.id,
2049                summary: Some(summary),
2050                more_summaries: worktree_diagnostics.collect(),
2051            };
2052            session.peer.send(session.connection_id, message)?;
2053        }
2054
2055        for settings_file in worktree.settings_files {
2056            session.peer.send(
2057                session.connection_id,
2058                proto::UpdateWorktreeSettings {
2059                    project_id: project_id.to_proto(),
2060                    worktree_id: worktree.id,
2061                    path: settings_file.path,
2062                    content: Some(settings_file.content),
2063                    kind: Some(settings_file.kind.to_proto() as i32),
2064                },
2065            )?;
2066        }
2067    }
2068
2069    for repository in mem::take(&mut project.repositories) {
2070        for update in split_repository_update(repository) {
2071            session.peer.send(session.connection_id, update)?;
2072        }
2073    }
2074
2075    for language_server in &project.language_servers {
2076        session.peer.send(
2077            session.connection_id,
2078            proto::UpdateLanguageServer {
2079                project_id: project_id.to_proto(),
2080                server_name: Some(language_server.server.name.clone()),
2081                language_server_id: language_server.server.id,
2082                variant: Some(
2083                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2084                        proto::LspDiskBasedDiagnosticsUpdated {},
2085                    ),
2086                ),
2087            },
2088        )?;
2089    }
2090
2091    Ok(())
2092}
2093
2094/// Leave someone elses shared project.
2095async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2096    let sender_id = session.connection_id;
2097    let project_id = ProjectId::from_proto(request.project_id);
2098    let db = session.db().await;
2099
2100    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2101    tracing::info!(
2102        %project_id,
2103        "leave project"
2104    );
2105
2106    project_left(project, &session);
2107    if let Some(room) = room {
2108        room_updated(room, &session.peer);
2109    }
2110
2111    Ok(())
2112}
2113
2114/// Updates other participants with changes to the project
2115async fn update_project(
2116    request: proto::UpdateProject,
2117    response: Response<proto::UpdateProject>,
2118    session: MessageContext,
2119) -> Result<()> {
2120    let project_id = ProjectId::from_proto(request.project_id);
2121    let (room, guest_connection_ids) = &*session
2122        .db()
2123        .await
2124        .update_project(project_id, session.connection_id, &request.worktrees)
2125        .await?;
2126    broadcast(
2127        Some(session.connection_id),
2128        guest_connection_ids.iter().copied(),
2129        |connection_id| {
2130            session
2131                .peer
2132                .forward_send(session.connection_id, connection_id, request.clone())
2133        },
2134    );
2135    if let Some(room) = room {
2136        room_updated(room, &session.peer);
2137    }
2138    response.send(proto::Ack {})?;
2139
2140    Ok(())
2141}
2142
2143/// Updates other participants with changes to the worktree
2144async fn update_worktree(
2145    request: proto::UpdateWorktree,
2146    response: Response<proto::UpdateWorktree>,
2147    session: MessageContext,
2148) -> Result<()> {
2149    let guest_connection_ids = session
2150        .db()
2151        .await
2152        .update_worktree(&request, session.connection_id)
2153        .await?;
2154
2155    broadcast(
2156        Some(session.connection_id),
2157        guest_connection_ids.iter().copied(),
2158        |connection_id| {
2159            session
2160                .peer
2161                .forward_send(session.connection_id, connection_id, request.clone())
2162        },
2163    );
2164    response.send(proto::Ack {})?;
2165    Ok(())
2166}
2167
2168async fn update_repository(
2169    request: proto::UpdateRepository,
2170    response: Response<proto::UpdateRepository>,
2171    session: MessageContext,
2172) -> Result<()> {
2173    let guest_connection_ids = session
2174        .db()
2175        .await
2176        .update_repository(&request, session.connection_id)
2177        .await?;
2178
2179    broadcast(
2180        Some(session.connection_id),
2181        guest_connection_ids.iter().copied(),
2182        |connection_id| {
2183            session
2184                .peer
2185                .forward_send(session.connection_id, connection_id, request.clone())
2186        },
2187    );
2188    response.send(proto::Ack {})?;
2189    Ok(())
2190}
2191
2192async fn remove_repository(
2193    request: proto::RemoveRepository,
2194    response: Response<proto::RemoveRepository>,
2195    session: MessageContext,
2196) -> Result<()> {
2197    let guest_connection_ids = session
2198        .db()
2199        .await
2200        .remove_repository(&request, session.connection_id)
2201        .await?;
2202
2203    broadcast(
2204        Some(session.connection_id),
2205        guest_connection_ids.iter().copied(),
2206        |connection_id| {
2207            session
2208                .peer
2209                .forward_send(session.connection_id, connection_id, request.clone())
2210        },
2211    );
2212    response.send(proto::Ack {})?;
2213    Ok(())
2214}
2215
2216/// Updates other participants with changes to the diagnostics
2217async fn update_diagnostic_summary(
2218    message: proto::UpdateDiagnosticSummary,
2219    session: MessageContext,
2220) -> Result<()> {
2221    let guest_connection_ids = session
2222        .db()
2223        .await
2224        .update_diagnostic_summary(&message, session.connection_id)
2225        .await?;
2226
2227    broadcast(
2228        Some(session.connection_id),
2229        guest_connection_ids.iter().copied(),
2230        |connection_id| {
2231            session
2232                .peer
2233                .forward_send(session.connection_id, connection_id, message.clone())
2234        },
2235    );
2236
2237    Ok(())
2238}
2239
2240/// Updates other participants with changes to the worktree settings
2241async fn update_worktree_settings(
2242    message: proto::UpdateWorktreeSettings,
2243    session: MessageContext,
2244) -> Result<()> {
2245    let guest_connection_ids = session
2246        .db()
2247        .await
2248        .update_worktree_settings(&message, session.connection_id)
2249        .await?;
2250
2251    broadcast(
2252        Some(session.connection_id),
2253        guest_connection_ids.iter().copied(),
2254        |connection_id| {
2255            session
2256                .peer
2257                .forward_send(session.connection_id, connection_id, message.clone())
2258        },
2259    );
2260
2261    Ok(())
2262}
2263
2264/// Notify other participants that a language server has started.
2265async fn start_language_server(
2266    request: proto::StartLanguageServer,
2267    session: MessageContext,
2268) -> Result<()> {
2269    let guest_connection_ids = session
2270        .db()
2271        .await
2272        .start_language_server(&request, session.connection_id)
2273        .await?;
2274
2275    broadcast(
2276        Some(session.connection_id),
2277        guest_connection_ids.iter().copied(),
2278        |connection_id| {
2279            session
2280                .peer
2281                .forward_send(session.connection_id, connection_id, request.clone())
2282        },
2283    );
2284    Ok(())
2285}
2286
2287/// Notify other participants that a language server has changed.
2288async fn update_language_server(
2289    request: proto::UpdateLanguageServer,
2290    session: MessageContext,
2291) -> Result<()> {
2292    let project_id = ProjectId::from_proto(request.project_id);
2293    let db = session.db().await;
2294
2295    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2296        && let Some(capabilities) = update.capabilities.clone()
2297    {
2298        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2299            .await?;
2300    }
2301
2302    let project_connection_ids = db
2303        .project_connection_ids(project_id, session.connection_id, true)
2304        .await?;
2305    broadcast(
2306        Some(session.connection_id),
2307        project_connection_ids.iter().copied(),
2308        |connection_id| {
2309            session
2310                .peer
2311                .forward_send(session.connection_id, connection_id, request.clone())
2312        },
2313    );
2314    Ok(())
2315}
2316
2317/// forward a project request to the host. These requests should be read only
2318/// as guests are allowed to send them.
2319async fn forward_read_only_project_request<T>(
2320    request: T,
2321    response: Response<T>,
2322    session: MessageContext,
2323) -> Result<()>
2324where
2325    T: EntityMessage + RequestMessage,
2326{
2327    let project_id = ProjectId::from_proto(request.remote_entity_id());
2328    let host_connection_id = session
2329        .db()
2330        .await
2331        .host_for_read_only_project_request(project_id, session.connection_id)
2332        .await?;
2333    let payload = session.forward_request(host_connection_id, request).await?;
2334    response.send(payload)?;
2335    Ok(())
2336}
2337
2338/// forward a project request to the host. These requests are disallowed
2339/// for guests.
2340async fn forward_mutating_project_request<T>(
2341    request: T,
2342    response: Response<T>,
2343    session: MessageContext,
2344) -> Result<()>
2345where
2346    T: EntityMessage + RequestMessage,
2347{
2348    let project_id = ProjectId::from_proto(request.remote_entity_id());
2349
2350    let host_connection_id = session
2351        .db()
2352        .await
2353        .host_for_mutating_project_request(project_id, session.connection_id)
2354        .await?;
2355    let payload = session.forward_request(host_connection_id, request).await?;
2356    response.send(payload)?;
2357    Ok(())
2358}
2359
2360async fn lsp_query(
2361    request: proto::LspQuery,
2362    response: Response<proto::LspQuery>,
2363    session: MessageContext,
2364) -> Result<()> {
2365    let (name, should_write) = request.query_name_and_write_permissions();
2366    tracing::Span::current().record("lsp_query_request", name);
2367    tracing::info!("lsp_query message received");
2368    if should_write {
2369        forward_mutating_project_request(request, response, session).await
2370    } else {
2371        forward_read_only_project_request(request, response, session).await
2372    }
2373}
2374
2375/// Notify other participants that a new buffer has been created
2376async fn create_buffer_for_peer(
2377    request: proto::CreateBufferForPeer,
2378    session: MessageContext,
2379) -> Result<()> {
2380    session
2381        .db()
2382        .await
2383        .check_user_is_project_host(
2384            ProjectId::from_proto(request.project_id),
2385            session.connection_id,
2386        )
2387        .await?;
2388    let peer_id = request.peer_id.context("invalid peer id")?;
2389    session
2390        .peer
2391        .forward_send(session.connection_id, peer_id.into(), request)?;
2392    Ok(())
2393}
2394
2395/// Notify other participants that a new image has been created
2396async fn create_image_for_peer(
2397    request: proto::CreateImageForPeer,
2398    session: MessageContext,
2399) -> Result<()> {
2400    session
2401        .db()
2402        .await
2403        .check_user_is_project_host(
2404            ProjectId::from_proto(request.project_id),
2405            session.connection_id,
2406        )
2407        .await?;
2408    let peer_id = request.peer_id.context("invalid peer id")?;
2409    session
2410        .peer
2411        .forward_send(session.connection_id, peer_id.into(), request)?;
2412    Ok(())
2413}
2414
2415/// Notify other participants that a buffer has been updated. This is
2416/// allowed for guests as long as the update is limited to selections.
2417async fn update_buffer(
2418    request: proto::UpdateBuffer,
2419    response: Response<proto::UpdateBuffer>,
2420    session: MessageContext,
2421) -> Result<()> {
2422    let project_id = ProjectId::from_proto(request.project_id);
2423    let mut capability = Capability::ReadOnly;
2424
2425    for op in request.operations.iter() {
2426        match op.variant {
2427            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2428            Some(_) => capability = Capability::ReadWrite,
2429        }
2430    }
2431
2432    let host = {
2433        let guard = session
2434            .db()
2435            .await
2436            .connections_for_buffer_update(project_id, session.connection_id, capability)
2437            .await?;
2438
2439        let (host, guests) = &*guard;
2440
2441        broadcast(
2442            Some(session.connection_id),
2443            guests.clone(),
2444            |connection_id| {
2445                session
2446                    .peer
2447                    .forward_send(session.connection_id, connection_id, request.clone())
2448            },
2449        );
2450
2451        *host
2452    };
2453
2454    if host != session.connection_id {
2455        session.forward_request(host, request.clone()).await?;
2456    }
2457
2458    response.send(proto::Ack {})?;
2459    Ok(())
2460}
2461
2462async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2463    let project_id = ProjectId::from_proto(message.project_id);
2464
2465    let operation = message.operation.as_ref().context("invalid operation")?;
2466    let capability = match operation.variant.as_ref() {
2467        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2468            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2469                match buffer_op.variant {
2470                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2471                        Capability::ReadOnly
2472                    }
2473                    _ => Capability::ReadWrite,
2474                }
2475            } else {
2476                Capability::ReadWrite
2477            }
2478        }
2479        Some(_) => Capability::ReadWrite,
2480        None => Capability::ReadOnly,
2481    };
2482
2483    let guard = session
2484        .db()
2485        .await
2486        .connections_for_buffer_update(project_id, session.connection_id, capability)
2487        .await?;
2488
2489    let (host, guests) = &*guard;
2490
2491    broadcast(
2492        Some(session.connection_id),
2493        guests.iter().chain([host]).copied(),
2494        |connection_id| {
2495            session
2496                .peer
2497                .forward_send(session.connection_id, connection_id, message.clone())
2498        },
2499    );
2500
2501    Ok(())
2502}
2503
2504/// Notify other participants that a project has been updated.
2505async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2506    request: T,
2507    session: MessageContext,
2508) -> Result<()> {
2509    let project_id = ProjectId::from_proto(request.remote_entity_id());
2510    let project_connection_ids = session
2511        .db()
2512        .await
2513        .project_connection_ids(project_id, session.connection_id, false)
2514        .await?;
2515
2516    broadcast(
2517        Some(session.connection_id),
2518        project_connection_ids.iter().copied(),
2519        |connection_id| {
2520            session
2521                .peer
2522                .forward_send(session.connection_id, connection_id, request.clone())
2523        },
2524    );
2525    Ok(())
2526}
2527
2528/// Start following another user in a call.
2529async fn follow(
2530    request: proto::Follow,
2531    response: Response<proto::Follow>,
2532    session: MessageContext,
2533) -> Result<()> {
2534    let room_id = RoomId::from_proto(request.room_id);
2535    let project_id = request.project_id.map(ProjectId::from_proto);
2536    let leader_id = request.leader_id.context("invalid leader id")?.into();
2537    let follower_id = session.connection_id;
2538
2539    session
2540        .db()
2541        .await
2542        .check_room_participants(room_id, leader_id, session.connection_id)
2543        .await?;
2544
2545    let response_payload = session.forward_request(leader_id, request).await?;
2546    response.send(response_payload)?;
2547
2548    if let Some(project_id) = project_id {
2549        let room = session
2550            .db()
2551            .await
2552            .follow(room_id, project_id, leader_id, follower_id)
2553            .await?;
2554        room_updated(&room, &session.peer);
2555    }
2556
2557    Ok(())
2558}
2559
2560/// Stop following another user in a call.
2561async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2562    let room_id = RoomId::from_proto(request.room_id);
2563    let project_id = request.project_id.map(ProjectId::from_proto);
2564    let leader_id = request.leader_id.context("invalid leader id")?.into();
2565    let follower_id = session.connection_id;
2566
2567    session
2568        .db()
2569        .await
2570        .check_room_participants(room_id, leader_id, session.connection_id)
2571        .await?;
2572
2573    session
2574        .peer
2575        .forward_send(session.connection_id, leader_id, request)?;
2576
2577    if let Some(project_id) = project_id {
2578        let room = session
2579            .db()
2580            .await
2581            .unfollow(room_id, project_id, leader_id, follower_id)
2582            .await?;
2583        room_updated(&room, &session.peer);
2584    }
2585
2586    Ok(())
2587}
2588
2589/// Notify everyone following you of your current location.
2590async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2591    let room_id = RoomId::from_proto(request.room_id);
2592    let database = session.db.lock().await;
2593
2594    let connection_ids = if let Some(project_id) = request.project_id {
2595        let project_id = ProjectId::from_proto(project_id);
2596        database
2597            .project_connection_ids(project_id, session.connection_id, true)
2598            .await?
2599    } else {
2600        database
2601            .room_connection_ids(room_id, session.connection_id)
2602            .await?
2603    };
2604
2605    // For now, don't send view update messages back to that view's current leader.
2606    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2607        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2608        _ => None,
2609    });
2610
2611    for connection_id in connection_ids.iter().cloned() {
2612        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2613            session
2614                .peer
2615                .forward_send(session.connection_id, connection_id, request.clone())?;
2616        }
2617    }
2618    Ok(())
2619}
2620
2621/// Get public data about users.
2622async fn get_users(
2623    request: proto::GetUsers,
2624    response: Response<proto::GetUsers>,
2625    session: MessageContext,
2626) -> Result<()> {
2627    let user_ids = request
2628        .user_ids
2629        .into_iter()
2630        .map(UserId::from_proto)
2631        .collect();
2632    let users = session
2633        .db()
2634        .await
2635        .get_users_by_ids(user_ids)
2636        .await?
2637        .into_iter()
2638        .map(|user| proto::User {
2639            id: user.id.to_proto(),
2640            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2641            github_login: user.github_login,
2642            name: user.name,
2643        })
2644        .collect();
2645    response.send(proto::UsersResponse { users })?;
2646    Ok(())
2647}
2648
2649/// Search for users (to invite) buy Github login
2650async fn fuzzy_search_users(
2651    request: proto::FuzzySearchUsers,
2652    response: Response<proto::FuzzySearchUsers>,
2653    session: MessageContext,
2654) -> Result<()> {
2655    let query = request.query;
2656    let users = match query.len() {
2657        0 => vec![],
2658        1 | 2 => session
2659            .db()
2660            .await
2661            .get_user_by_github_login(&query)
2662            .await?
2663            .into_iter()
2664            .collect(),
2665        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2666    };
2667    let users = users
2668        .into_iter()
2669        .filter(|user| user.id != session.user_id())
2670        .map(|user| proto::User {
2671            id: user.id.to_proto(),
2672            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2673            github_login: user.github_login,
2674            name: user.name,
2675        })
2676        .collect();
2677    response.send(proto::UsersResponse { users })?;
2678    Ok(())
2679}
2680
2681/// Send a contact request to another user.
2682async fn request_contact(
2683    request: proto::RequestContact,
2684    response: Response<proto::RequestContact>,
2685    session: MessageContext,
2686) -> Result<()> {
2687    let requester_id = session.user_id();
2688    let responder_id = UserId::from_proto(request.responder_id);
2689    if requester_id == responder_id {
2690        return Err(anyhow!("cannot add yourself as a contact"))?;
2691    }
2692
2693    let notifications = session
2694        .db()
2695        .await
2696        .send_contact_request(requester_id, responder_id)
2697        .await?;
2698
2699    // Update outgoing contact requests of requester
2700    let mut update = proto::UpdateContacts::default();
2701    update.outgoing_requests.push(responder_id.to_proto());
2702    for connection_id in session
2703        .connection_pool()
2704        .await
2705        .user_connection_ids(requester_id)
2706    {
2707        session.peer.send(connection_id, update.clone())?;
2708    }
2709
2710    // Update incoming contact requests of responder
2711    let mut update = proto::UpdateContacts::default();
2712    update
2713        .incoming_requests
2714        .push(proto::IncomingContactRequest {
2715            requester_id: requester_id.to_proto(),
2716        });
2717    let connection_pool = session.connection_pool().await;
2718    for connection_id in connection_pool.user_connection_ids(responder_id) {
2719        session.peer.send(connection_id, update.clone())?;
2720    }
2721
2722    send_notifications(&connection_pool, &session.peer, notifications);
2723
2724    response.send(proto::Ack {})?;
2725    Ok(())
2726}
2727
2728/// Accept or decline a contact request
2729async fn respond_to_contact_request(
2730    request: proto::RespondToContactRequest,
2731    response: Response<proto::RespondToContactRequest>,
2732    session: MessageContext,
2733) -> Result<()> {
2734    let responder_id = session.user_id();
2735    let requester_id = UserId::from_proto(request.requester_id);
2736    let db = session.db().await;
2737    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2738        db.dismiss_contact_notification(responder_id, requester_id)
2739            .await?;
2740    } else {
2741        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2742
2743        let notifications = db
2744            .respond_to_contact_request(responder_id, requester_id, accept)
2745            .await?;
2746        let requester_busy = db.is_user_busy(requester_id).await?;
2747        let responder_busy = db.is_user_busy(responder_id).await?;
2748
2749        let pool = session.connection_pool().await;
2750        // Update responder with new contact
2751        let mut update = proto::UpdateContacts::default();
2752        if accept {
2753            update
2754                .contacts
2755                .push(contact_for_user(requester_id, requester_busy, &pool));
2756        }
2757        update
2758            .remove_incoming_requests
2759            .push(requester_id.to_proto());
2760        for connection_id in pool.user_connection_ids(responder_id) {
2761            session.peer.send(connection_id, update.clone())?;
2762        }
2763
2764        // Update requester with new contact
2765        let mut update = proto::UpdateContacts::default();
2766        if accept {
2767            update
2768                .contacts
2769                .push(contact_for_user(responder_id, responder_busy, &pool));
2770        }
2771        update
2772            .remove_outgoing_requests
2773            .push(responder_id.to_proto());
2774
2775        for connection_id in pool.user_connection_ids(requester_id) {
2776            session.peer.send(connection_id, update.clone())?;
2777        }
2778
2779        send_notifications(&pool, &session.peer, notifications);
2780    }
2781
2782    response.send(proto::Ack {})?;
2783    Ok(())
2784}
2785
2786/// Remove a contact.
2787async fn remove_contact(
2788    request: proto::RemoveContact,
2789    response: Response<proto::RemoveContact>,
2790    session: MessageContext,
2791) -> Result<()> {
2792    let requester_id = session.user_id();
2793    let responder_id = UserId::from_proto(request.user_id);
2794    let db = session.db().await;
2795    let (contact_accepted, deleted_notification_id) =
2796        db.remove_contact(requester_id, responder_id).await?;
2797
2798    let pool = session.connection_pool().await;
2799    // Update outgoing contact requests of requester
2800    let mut update = proto::UpdateContacts::default();
2801    if contact_accepted {
2802        update.remove_contacts.push(responder_id.to_proto());
2803    } else {
2804        update
2805            .remove_outgoing_requests
2806            .push(responder_id.to_proto());
2807    }
2808    for connection_id in pool.user_connection_ids(requester_id) {
2809        session.peer.send(connection_id, update.clone())?;
2810    }
2811
2812    // Update incoming contact requests of responder
2813    let mut update = proto::UpdateContacts::default();
2814    if contact_accepted {
2815        update.remove_contacts.push(requester_id.to_proto());
2816    } else {
2817        update
2818            .remove_incoming_requests
2819            .push(requester_id.to_proto());
2820    }
2821    for connection_id in pool.user_connection_ids(responder_id) {
2822        session.peer.send(connection_id, update.clone())?;
2823        if let Some(notification_id) = deleted_notification_id {
2824            session.peer.send(
2825                connection_id,
2826                proto::DeleteNotification {
2827                    notification_id: notification_id.to_proto(),
2828                },
2829            )?;
2830        }
2831    }
2832
2833    response.send(proto::Ack {})?;
2834    Ok(())
2835}
2836
2837fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
2838    version.0.minor < 139
2839}
2840
2841async fn subscribe_to_channels(
2842    _: proto::SubscribeToChannels,
2843    session: MessageContext,
2844) -> Result<()> {
2845    subscribe_user_to_channels(session.user_id(), &session).await?;
2846    Ok(())
2847}
2848
2849async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2850    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2851    let mut pool = session.connection_pool().await;
2852    for membership in &channels_for_user.channel_memberships {
2853        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2854    }
2855    session.peer.send(
2856        session.connection_id,
2857        build_update_user_channels(&channels_for_user),
2858    )?;
2859    session.peer.send(
2860        session.connection_id,
2861        build_channels_update(channels_for_user),
2862    )?;
2863    Ok(())
2864}
2865
2866/// Creates a new channel.
2867async fn create_channel(
2868    request: proto::CreateChannel,
2869    response: Response<proto::CreateChannel>,
2870    session: MessageContext,
2871) -> Result<()> {
2872    let db = session.db().await;
2873
2874    let parent_id = request.parent_id.map(ChannelId::from_proto);
2875    let (channel, membership) = db
2876        .create_channel(&request.name, parent_id, session.user_id())
2877        .await?;
2878
2879    let root_id = channel.root_id();
2880    let channel = Channel::from_model(channel);
2881
2882    response.send(proto::CreateChannelResponse {
2883        channel: Some(channel.to_proto()),
2884        parent_id: request.parent_id,
2885    })?;
2886
2887    let mut connection_pool = session.connection_pool().await;
2888    if let Some(membership) = membership {
2889        connection_pool.subscribe_to_channel(
2890            membership.user_id,
2891            membership.channel_id,
2892            membership.role,
2893        );
2894        let update = proto::UpdateUserChannels {
2895            channel_memberships: vec![proto::ChannelMembership {
2896                channel_id: membership.channel_id.to_proto(),
2897                role: membership.role.into(),
2898            }],
2899            ..Default::default()
2900        };
2901        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2902            session.peer.send(connection_id, update.clone())?;
2903        }
2904    }
2905
2906    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2907        if !role.can_see_channel(channel.visibility) {
2908            continue;
2909        }
2910
2911        let update = proto::UpdateChannels {
2912            channels: vec![channel.to_proto()],
2913            ..Default::default()
2914        };
2915        session.peer.send(connection_id, update.clone())?;
2916    }
2917
2918    Ok(())
2919}
2920
2921/// Delete a channel
2922async fn delete_channel(
2923    request: proto::DeleteChannel,
2924    response: Response<proto::DeleteChannel>,
2925    session: MessageContext,
2926) -> Result<()> {
2927    let db = session.db().await;
2928
2929    let channel_id = request.channel_id;
2930    let (root_channel, removed_channels) = db
2931        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2932        .await?;
2933    response.send(proto::Ack {})?;
2934
2935    // Notify members of removed channels
2936    let mut update = proto::UpdateChannels::default();
2937    update
2938        .delete_channels
2939        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2940
2941    let connection_pool = session.connection_pool().await;
2942    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2943        session.peer.send(connection_id, update.clone())?;
2944    }
2945
2946    Ok(())
2947}
2948
2949/// Invite someone to join a channel.
2950async fn invite_channel_member(
2951    request: proto::InviteChannelMember,
2952    response: Response<proto::InviteChannelMember>,
2953    session: MessageContext,
2954) -> Result<()> {
2955    let db = session.db().await;
2956    let channel_id = ChannelId::from_proto(request.channel_id);
2957    let invitee_id = UserId::from_proto(request.user_id);
2958    let InviteMemberResult {
2959        channel,
2960        notifications,
2961    } = db
2962        .invite_channel_member(
2963            channel_id,
2964            invitee_id,
2965            session.user_id(),
2966            request.role().into(),
2967        )
2968        .await?;
2969
2970    let update = proto::UpdateChannels {
2971        channel_invitations: vec![channel.to_proto()],
2972        ..Default::default()
2973    };
2974
2975    let connection_pool = session.connection_pool().await;
2976    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2977        session.peer.send(connection_id, update.clone())?;
2978    }
2979
2980    send_notifications(&connection_pool, &session.peer, notifications);
2981
2982    response.send(proto::Ack {})?;
2983    Ok(())
2984}
2985
2986/// remove someone from a channel
2987async fn remove_channel_member(
2988    request: proto::RemoveChannelMember,
2989    response: Response<proto::RemoveChannelMember>,
2990    session: MessageContext,
2991) -> Result<()> {
2992    let db = session.db().await;
2993    let channel_id = ChannelId::from_proto(request.channel_id);
2994    let member_id = UserId::from_proto(request.user_id);
2995
2996    let RemoveChannelMemberResult {
2997        membership_update,
2998        notification_id,
2999    } = db
3000        .remove_channel_member(channel_id, member_id, session.user_id())
3001        .await?;
3002
3003    let mut connection_pool = session.connection_pool().await;
3004    notify_membership_updated(
3005        &mut connection_pool,
3006        membership_update,
3007        member_id,
3008        &session.peer,
3009    );
3010    for connection_id in connection_pool.user_connection_ids(member_id) {
3011        if let Some(notification_id) = notification_id {
3012            session
3013                .peer
3014                .send(
3015                    connection_id,
3016                    proto::DeleteNotification {
3017                        notification_id: notification_id.to_proto(),
3018                    },
3019                )
3020                .trace_err();
3021        }
3022    }
3023
3024    response.send(proto::Ack {})?;
3025    Ok(())
3026}
3027
3028/// Toggle the channel between public and private.
3029/// Care is taken to maintain the invariant that public channels only descend from public channels,
3030/// (though members-only channels can appear at any point in the hierarchy).
3031async fn set_channel_visibility(
3032    request: proto::SetChannelVisibility,
3033    response: Response<proto::SetChannelVisibility>,
3034    session: MessageContext,
3035) -> Result<()> {
3036    let db = session.db().await;
3037    let channel_id = ChannelId::from_proto(request.channel_id);
3038    let visibility = request.visibility().into();
3039
3040    let channel_model = db
3041        .set_channel_visibility(channel_id, visibility, session.user_id())
3042        .await?;
3043    let root_id = channel_model.root_id();
3044    let channel = Channel::from_model(channel_model);
3045
3046    let mut connection_pool = session.connection_pool().await;
3047    for (user_id, role) in connection_pool
3048        .channel_user_ids(root_id)
3049        .collect::<Vec<_>>()
3050        .into_iter()
3051    {
3052        let update = if role.can_see_channel(channel.visibility) {
3053            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3054            proto::UpdateChannels {
3055                channels: vec![channel.to_proto()],
3056                ..Default::default()
3057            }
3058        } else {
3059            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3060            proto::UpdateChannels {
3061                delete_channels: vec![channel.id.to_proto()],
3062                ..Default::default()
3063            }
3064        };
3065
3066        for connection_id in connection_pool.user_connection_ids(user_id) {
3067            session.peer.send(connection_id, update.clone())?;
3068        }
3069    }
3070
3071    response.send(proto::Ack {})?;
3072    Ok(())
3073}
3074
3075/// Alter the role for a user in the channel.
3076async fn set_channel_member_role(
3077    request: proto::SetChannelMemberRole,
3078    response: Response<proto::SetChannelMemberRole>,
3079    session: MessageContext,
3080) -> Result<()> {
3081    let db = session.db().await;
3082    let channel_id = ChannelId::from_proto(request.channel_id);
3083    let member_id = UserId::from_proto(request.user_id);
3084    let result = db
3085        .set_channel_member_role(
3086            channel_id,
3087            session.user_id(),
3088            member_id,
3089            request.role().into(),
3090        )
3091        .await?;
3092
3093    match result {
3094        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3095            let mut connection_pool = session.connection_pool().await;
3096            notify_membership_updated(
3097                &mut connection_pool,
3098                membership_update,
3099                member_id,
3100                &session.peer,
3101            )
3102        }
3103        db::SetMemberRoleResult::InviteUpdated(channel) => {
3104            let update = proto::UpdateChannels {
3105                channel_invitations: vec![channel.to_proto()],
3106                ..Default::default()
3107            };
3108
3109            for connection_id in session
3110                .connection_pool()
3111                .await
3112                .user_connection_ids(member_id)
3113            {
3114                session.peer.send(connection_id, update.clone())?;
3115            }
3116        }
3117    }
3118
3119    response.send(proto::Ack {})?;
3120    Ok(())
3121}
3122
3123/// Change the name of a channel
3124async fn rename_channel(
3125    request: proto::RenameChannel,
3126    response: Response<proto::RenameChannel>,
3127    session: MessageContext,
3128) -> Result<()> {
3129    let db = session.db().await;
3130    let channel_id = ChannelId::from_proto(request.channel_id);
3131    let channel_model = db
3132        .rename_channel(channel_id, session.user_id(), &request.name)
3133        .await?;
3134    let root_id = channel_model.root_id();
3135    let channel = Channel::from_model(channel_model);
3136
3137    response.send(proto::RenameChannelResponse {
3138        channel: Some(channel.to_proto()),
3139    })?;
3140
3141    let connection_pool = session.connection_pool().await;
3142    let update = proto::UpdateChannels {
3143        channels: vec![channel.to_proto()],
3144        ..Default::default()
3145    };
3146    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3147        if role.can_see_channel(channel.visibility) {
3148            session.peer.send(connection_id, update.clone())?;
3149        }
3150    }
3151
3152    Ok(())
3153}
3154
3155/// Move a channel to a new parent.
3156async fn move_channel(
3157    request: proto::MoveChannel,
3158    response: Response<proto::MoveChannel>,
3159    session: MessageContext,
3160) -> Result<()> {
3161    let channel_id = ChannelId::from_proto(request.channel_id);
3162    let to = ChannelId::from_proto(request.to);
3163
3164    let (root_id, channels) = session
3165        .db()
3166        .await
3167        .move_channel(channel_id, to, session.user_id())
3168        .await?;
3169
3170    let connection_pool = session.connection_pool().await;
3171    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3172        let channels = channels
3173            .iter()
3174            .filter_map(|channel| {
3175                if role.can_see_channel(channel.visibility) {
3176                    Some(channel.to_proto())
3177                } else {
3178                    None
3179                }
3180            })
3181            .collect::<Vec<_>>();
3182        if channels.is_empty() {
3183            continue;
3184        }
3185
3186        let update = proto::UpdateChannels {
3187            channels,
3188            ..Default::default()
3189        };
3190
3191        session.peer.send(connection_id, update.clone())?;
3192    }
3193
3194    response.send(Ack {})?;
3195    Ok(())
3196}
3197
3198async fn reorder_channel(
3199    request: proto::ReorderChannel,
3200    response: Response<proto::ReorderChannel>,
3201    session: MessageContext,
3202) -> Result<()> {
3203    let channel_id = ChannelId::from_proto(request.channel_id);
3204    let direction = request.direction();
3205
3206    let updated_channels = session
3207        .db()
3208        .await
3209        .reorder_channel(channel_id, direction, session.user_id())
3210        .await?;
3211
3212    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3213        let connection_pool = session.connection_pool().await;
3214        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3215            let channels = updated_channels
3216                .iter()
3217                .filter_map(|channel| {
3218                    if role.can_see_channel(channel.visibility) {
3219                        Some(channel.to_proto())
3220                    } else {
3221                        None
3222                    }
3223                })
3224                .collect::<Vec<_>>();
3225
3226            if channels.is_empty() {
3227                continue;
3228            }
3229
3230            let update = proto::UpdateChannels {
3231                channels,
3232                ..Default::default()
3233            };
3234
3235            session.peer.send(connection_id, update.clone())?;
3236        }
3237    }
3238
3239    response.send(Ack {})?;
3240    Ok(())
3241}
3242
3243/// Get the list of channel members
3244async fn get_channel_members(
3245    request: proto::GetChannelMembers,
3246    response: Response<proto::GetChannelMembers>,
3247    session: MessageContext,
3248) -> Result<()> {
3249    let db = session.db().await;
3250    let channel_id = ChannelId::from_proto(request.channel_id);
3251    let limit = if request.limit == 0 {
3252        u16::MAX as u64
3253    } else {
3254        request.limit
3255    };
3256    let (members, users) = db
3257        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3258        .await?;
3259    response.send(proto::GetChannelMembersResponse { members, users })?;
3260    Ok(())
3261}
3262
3263/// Accept or decline a channel invitation.
3264async fn respond_to_channel_invite(
3265    request: proto::RespondToChannelInvite,
3266    response: Response<proto::RespondToChannelInvite>,
3267    session: MessageContext,
3268) -> Result<()> {
3269    let db = session.db().await;
3270    let channel_id = ChannelId::from_proto(request.channel_id);
3271    let RespondToChannelInvite {
3272        membership_update,
3273        notifications,
3274    } = db
3275        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3276        .await?;
3277
3278    let mut connection_pool = session.connection_pool().await;
3279    if let Some(membership_update) = membership_update {
3280        notify_membership_updated(
3281            &mut connection_pool,
3282            membership_update,
3283            session.user_id(),
3284            &session.peer,
3285        );
3286    } else {
3287        let update = proto::UpdateChannels {
3288            remove_channel_invitations: vec![channel_id.to_proto()],
3289            ..Default::default()
3290        };
3291
3292        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3293            session.peer.send(connection_id, update.clone())?;
3294        }
3295    };
3296
3297    send_notifications(&connection_pool, &session.peer, notifications);
3298
3299    response.send(proto::Ack {})?;
3300
3301    Ok(())
3302}
3303
3304/// Join the channels' room
3305async fn join_channel(
3306    request: proto::JoinChannel,
3307    response: Response<proto::JoinChannel>,
3308    session: MessageContext,
3309) -> Result<()> {
3310    let channel_id = ChannelId::from_proto(request.channel_id);
3311    join_channel_internal(channel_id, Box::new(response), session).await
3312}
3313
/// A response handle that can answer a "join channel" request with a
/// `proto::JoinRoomResponse`.
///
/// Implemented for the `proto::JoinChannel` and `proto::JoinRoom` response
/// types, so `join_channel_internal` can reply to either kind of request.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Both impls forward through a fully qualified path. Inherent associated
// functions take precedence over trait methods, so this calls the inherent
// `Response::send` rather than recursing into the trait method being defined.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3327
/// Join the room of channel `channel_id` with the session's connection.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, that
///    connection is removed from its room first (`leave_room_for_session`).
/// 2. The user joins the channel's room in the database.
/// 3. The reply is sent. It contains the room, the channel id and, when a
///    LiveKit client is configured, LiveKit connection info.
/// 4. Any membership change is pushed to the user's connections and the
///    updated room is broadcast to its participants.
/// 5. The channel's participant list and the user's contacts are refreshed.
///
/// # Errors
///
/// Returns database errors, reply-send errors, errors from the stale-connection
/// cleanup and contact updates, and a "channel not returned" error when the
/// joined room has no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    // Both the database guard and the connection-pool guard acquired inside this
    // block are released at its end, before `channel_updated` below takes the
    // pool again.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` acquires the database itself, so release
            // our guard for the duration of the cleanup and re-acquire it after.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when token creation
        // fails (the error is logged by `trace_err`). Guests get a guest token
        // and may not publish; every other role gets a room token and may.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // NOTE(review): the reply above has already been sent by the time this
    // "channel not returned" error can fire, so the client will have seen a
    // successful join even if this fails.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3421
3422/// Start editing the channel notes
3423async fn join_channel_buffer(
3424    request: proto::JoinChannelBuffer,
3425    response: Response<proto::JoinChannelBuffer>,
3426    session: MessageContext,
3427) -> Result<()> {
3428    let db = session.db().await;
3429    let channel_id = ChannelId::from_proto(request.channel_id);
3430
3431    let open_response = db
3432        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3433        .await?;
3434
3435    let collaborators = open_response.collaborators.clone();
3436    response.send(open_response)?;
3437
3438    let update = UpdateChannelBufferCollaborators {
3439        channel_id: channel_id.to_proto(),
3440        collaborators: collaborators.clone(),
3441    };
3442    channel_buffer_updated(
3443        session.connection_id,
3444        collaborators
3445            .iter()
3446            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3447        &update,
3448        &session.peer,
3449    );
3450
3451    Ok(())
3452}
3453
3454/// Edit the channel notes
3455async fn update_channel_buffer(
3456    request: proto::UpdateChannelBuffer,
3457    session: MessageContext,
3458) -> Result<()> {
3459    let db = session.db().await;
3460    let channel_id = ChannelId::from_proto(request.channel_id);
3461
3462    let (collaborators, epoch, version) = db
3463        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3464        .await?;
3465
3466    channel_buffer_updated(
3467        session.connection_id,
3468        collaborators.clone(),
3469        &proto::UpdateChannelBuffer {
3470            channel_id: channel_id.to_proto(),
3471            operations: request.operations,
3472        },
3473        &session.peer,
3474    );
3475
3476    let pool = &*session.connection_pool().await;
3477
3478    let non_collaborators =
3479        pool.channel_connection_ids(channel_id)
3480            .filter_map(|(connection_id, _)| {
3481                if collaborators.contains(&connection_id) {
3482                    None
3483                } else {
3484                    Some(connection_id)
3485                }
3486            });
3487
3488    broadcast(None, non_collaborators, |peer_id| {
3489        session.peer.send(
3490            peer_id,
3491            proto::UpdateChannels {
3492                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3493                    channel_id: channel_id.to_proto(),
3494                    epoch: epoch as u64,
3495                    version: version.clone(),
3496                }],
3497                ..Default::default()
3498            },
3499        )
3500    });
3501
3502    Ok(())
3503}
3504
3505/// Rejoin the channel notes after a connection blip
3506async fn rejoin_channel_buffers(
3507    request: proto::RejoinChannelBuffers,
3508    response: Response<proto::RejoinChannelBuffers>,
3509    session: MessageContext,
3510) -> Result<()> {
3511    let db = session.db().await;
3512    let buffers = db
3513        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3514        .await?;
3515
3516    for rejoined_buffer in &buffers {
3517        let collaborators_to_notify = rejoined_buffer
3518            .buffer
3519            .collaborators
3520            .iter()
3521            .filter_map(|c| Some(c.peer_id?.into()));
3522        channel_buffer_updated(
3523            session.connection_id,
3524            collaborators_to_notify,
3525            &proto::UpdateChannelBufferCollaborators {
3526                channel_id: rejoined_buffer.buffer.channel_id,
3527                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3528            },
3529            &session.peer,
3530        );
3531    }
3532
3533    response.send(proto::RejoinChannelBuffersResponse {
3534        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3535    })?;
3536
3537    Ok(())
3538}
3539
3540/// Stop editing the channel notes
3541async fn leave_channel_buffer(
3542    request: proto::LeaveChannelBuffer,
3543    response: Response<proto::LeaveChannelBuffer>,
3544    session: MessageContext,
3545) -> Result<()> {
3546    let db = session.db().await;
3547    let channel_id = ChannelId::from_proto(request.channel_id);
3548
3549    let left_buffer = db
3550        .leave_channel_buffer(channel_id, session.connection_id)
3551        .await?;
3552
3553    response.send(Ack {})?;
3554
3555    channel_buffer_updated(
3556        session.connection_id,
3557        left_buffer.connections,
3558        &proto::UpdateChannelBufferCollaborators {
3559            channel_id: channel_id.to_proto(),
3560            collaborators: left_buffer.collaborators,
3561        },
3562        &session.peer,
3563    );
3564
3565    Ok(())
3566}
3567
3568fn channel_buffer_updated<T: EnvelopedMessage>(
3569    sender_id: ConnectionId,
3570    collaborators: impl IntoIterator<Item = ConnectionId>,
3571    message: &T,
3572    peer: &Peer,
3573) {
3574    broadcast(Some(sender_id), collaborators, |peer_id| {
3575        peer.send(peer_id, message.clone())
3576    });
3577}
3578
3579fn send_notifications(
3580    connection_pool: &ConnectionPool,
3581    peer: &Peer,
3582    notifications: db::NotificationBatch,
3583) {
3584    for (user_id, notification) in notifications {
3585        for connection_id in connection_pool.user_connection_ids(user_id) {
3586            if let Err(error) = peer.send(
3587                connection_id,
3588                proto::AddNotification {
3589                    notification: Some(notification.clone()),
3590                },
3591            ) {
3592                tracing::error!(
3593                    "failed to send notification to {:?} {}",
3594                    connection_id,
3595                    error
3596                );
3597            }
3598        }
3599    }
3600}
3601
3602/// Send a message to the channel
3603async fn send_channel_message(
3604    _request: proto::SendChannelMessage,
3605    _response: Response<proto::SendChannelMessage>,
3606    _session: MessageContext,
3607) -> Result<()> {
3608    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3609}
3610
3611/// Delete a channel message
3612async fn remove_channel_message(
3613    _request: proto::RemoveChannelMessage,
3614    _response: Response<proto::RemoveChannelMessage>,
3615    _session: MessageContext,
3616) -> Result<()> {
3617    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3618}
3619
3620async fn update_channel_message(
3621    _request: proto::UpdateChannelMessage,
3622    _response: Response<proto::UpdateChannelMessage>,
3623    _session: MessageContext,
3624) -> Result<()> {
3625    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3626}
3627
3628/// Mark a channel message as read
3629async fn acknowledge_channel_message(
3630    _request: proto::AckChannelMessage,
3631    _session: MessageContext,
3632) -> Result<()> {
3633    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3634}
3635
3636/// Mark a buffer version as synced
3637async fn acknowledge_buffer_version(
3638    request: proto::AckBufferOperation,
3639    session: MessageContext,
3640) -> Result<()> {
3641    let buffer_id = BufferId::from_proto(request.buffer_id);
3642    session
3643        .db()
3644        .await
3645        .observe_buffer_version(
3646            buffer_id,
3647            session.user_id(),
3648            request.epoch as i32,
3649            &request.version,
3650        )
3651        .await?;
3652    Ok(())
3653}
3654
3655/// Get a Supermaven API key for the user
3656async fn get_supermaven_api_key(
3657    _request: proto::GetSupermavenApiKey,
3658    response: Response<proto::GetSupermavenApiKey>,
3659    session: MessageContext,
3660) -> Result<()> {
3661    let user_id: String = session.user_id().to_string();
3662    if !session.is_staff() {
3663        return Err(anyhow!("supermaven not enabled for this account"))?;
3664    }
3665
3666    let email = session.email().context("user must have an email")?;
3667
3668    let supermaven_admin_api = session
3669        .supermaven_client
3670        .as_ref()
3671        .context("supermaven not configured")?;
3672
3673    let result = supermaven_admin_api
3674        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3675        .await?;
3676
3677    response.send(proto::GetSupermavenApiKeyResponse {
3678        api_key: result.api_key,
3679    })?;
3680
3681    Ok(())
3682}
3683
3684/// Start receiving chat updates for a channel
3685async fn join_channel_chat(
3686    _request: proto::JoinChannelChat,
3687    _response: Response<proto::JoinChannelChat>,
3688    _session: MessageContext,
3689) -> Result<()> {
3690    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3691}
3692
3693/// Stop receiving chat updates for a channel
3694async fn leave_channel_chat(
3695    _request: proto::LeaveChannelChat,
3696    _session: MessageContext,
3697) -> Result<()> {
3698    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3699}
3700
3701/// Retrieve the chat history for a channel
3702async fn get_channel_messages(
3703    _request: proto::GetChannelMessages,
3704    _response: Response<proto::GetChannelMessages>,
3705    _session: MessageContext,
3706) -> Result<()> {
3707    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3708}
3709
3710/// Retrieve specific chat messages
3711async fn get_channel_messages_by_id(
3712    _request: proto::GetChannelMessagesById,
3713    _response: Response<proto::GetChannelMessagesById>,
3714    _session: MessageContext,
3715) -> Result<()> {
3716    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3717}
3718
3719/// Retrieve the current users notifications
3720async fn get_notifications(
3721    request: proto::GetNotifications,
3722    response: Response<proto::GetNotifications>,
3723    session: MessageContext,
3724) -> Result<()> {
3725    let notifications = session
3726        .db()
3727        .await
3728        .get_notifications(
3729            session.user_id(),
3730            NOTIFICATION_COUNT_PER_PAGE,
3731            request.before_id.map(db::NotificationId::from_proto),
3732        )
3733        .await?;
3734    response.send(proto::GetNotificationsResponse {
3735        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3736        notifications,
3737    })?;
3738    Ok(())
3739}
3740
3741/// Mark notifications as read
3742async fn mark_notification_as_read(
3743    request: proto::MarkNotificationRead,
3744    response: Response<proto::MarkNotificationRead>,
3745    session: MessageContext,
3746) -> Result<()> {
3747    let database = &session.db().await;
3748    let notifications = database
3749        .mark_notification_as_read_by_id(
3750            session.user_id(),
3751            NotificationId::from_proto(request.notification_id),
3752        )
3753        .await?;
3754    send_notifications(
3755        &*session.connection_pool().await,
3756        &session.peer,
3757        notifications,
3758    );
3759    response.send(proto::Ack {})?;
3760    Ok(())
3761}
3762
3763fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3764    let message = match message {
3765        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3766        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3767        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3768        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3769        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3770            code: frame.code.into(),
3771            reason: frame.reason.as_str().to_owned().into(),
3772        })),
3773        // We should never receive a frame while reading the message, according
3774        // to the `tungstenite` maintainers:
3775        //
3776        // > It cannot occur when you read messages from the WebSocket, but it
3777        // > can be used when you want to send the raw frames (e.g. you want to
3778        // > send the frames to the WebSocket without composing the full message first).
3779        // >
3780        // > — https://github.com/snapview/tungstenite-rs/issues/268
3781        TungsteniteMessage::Frame(_) => {
3782            bail!("received an unexpected frame while reading the message")
3783        }
3784    };
3785
3786    Ok(message)
3787}
3788
3789fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3790    match message {
3791        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3792        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3793        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3794        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3795        AxumMessage::Close(frame) => {
3796            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3797                code: frame.code.into(),
3798                reason: frame.reason.as_ref().into(),
3799            }))
3800        }
3801    }
3802}
3803
3804fn notify_membership_updated(
3805    connection_pool: &mut ConnectionPool,
3806    result: MembershipUpdated,
3807    user_id: UserId,
3808    peer: &Peer,
3809) {
3810    for membership in &result.new_channels.channel_memberships {
3811        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3812    }
3813    for channel_id in &result.removed_channels {
3814        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3815    }
3816
3817    let user_channels_update = proto::UpdateUserChannels {
3818        channel_memberships: result
3819            .new_channels
3820            .channel_memberships
3821            .iter()
3822            .map(|cm| proto::ChannelMembership {
3823                channel_id: cm.channel_id.to_proto(),
3824                role: cm.role.into(),
3825            })
3826            .collect(),
3827        ..Default::default()
3828    };
3829
3830    let mut update = build_channels_update(result.new_channels);
3831    update.delete_channels = result
3832        .removed_channels
3833        .into_iter()
3834        .map(|id| id.to_proto())
3835        .collect();
3836    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3837
3838    for connection_id in connection_pool.user_connection_ids(user_id) {
3839        peer.send(connection_id, user_channels_update.clone())
3840            .trace_err();
3841        peer.send(connection_id, update.clone()).trace_err();
3842    }
3843}
3844
3845fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3846    proto::UpdateUserChannels {
3847        channel_memberships: channels
3848            .channel_memberships
3849            .iter()
3850            .map(|m| proto::ChannelMembership {
3851                channel_id: m.channel_id.to_proto(),
3852                role: m.role.into(),
3853            })
3854            .collect(),
3855        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3856    }
3857}
3858
3859fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3860    let mut update = proto::UpdateChannels::default();
3861
3862    for channel in channels.channels {
3863        update.channels.push(channel.to_proto());
3864    }
3865
3866    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3867
3868    for (channel_id, participants) in channels.channel_participants {
3869        update
3870            .channel_participants
3871            .push(proto::ChannelParticipants {
3872                channel_id: channel_id.to_proto(),
3873                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3874            });
3875    }
3876
3877    for channel in channels.invited_channels {
3878        update.channel_invitations.push(channel.to_proto());
3879    }
3880
3881    update
3882}
3883
3884fn build_initial_contacts_update(
3885    contacts: Vec<db::Contact>,
3886    pool: &ConnectionPool,
3887) -> proto::UpdateContacts {
3888    let mut update = proto::UpdateContacts::default();
3889
3890    for contact in contacts {
3891        match contact {
3892            db::Contact::Accepted { user_id, busy } => {
3893                update.contacts.push(contact_for_user(user_id, busy, pool));
3894            }
3895            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3896            db::Contact::Incoming { user_id } => {
3897                update
3898                    .incoming_requests
3899                    .push(proto::IncomingContactRequest {
3900                        requester_id: user_id.to_proto(),
3901                    })
3902            }
3903        }
3904    }
3905
3906    update
3907}
3908
3909fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3910    proto::Contact {
3911        user_id: user_id.to_proto(),
3912        online: pool.is_user_online(user_id),
3913        busy,
3914    }
3915}
3916
3917fn room_updated(room: &proto::Room, peer: &Peer) {
3918    broadcast(
3919        None,
3920        room.participants
3921            .iter()
3922            .filter_map(|participant| Some(participant.peer_id?.into())),
3923        |peer_id| {
3924            peer.send(
3925                peer_id,
3926                proto::RoomUpdated {
3927                    room: Some(room.clone()),
3928                },
3929            )
3930        },
3931    );
3932}
3933
3934fn channel_updated(
3935    channel: &db::channel::Model,
3936    room: &proto::Room,
3937    peer: &Peer,
3938    pool: &ConnectionPool,
3939) {
3940    let participants = room
3941        .participants
3942        .iter()
3943        .map(|p| p.user_id)
3944        .collect::<Vec<_>>();
3945
3946    broadcast(
3947        None,
3948        pool.channel_connection_ids(channel.root_id())
3949            .filter_map(|(channel_id, role)| {
3950                role.can_see_channel(channel.visibility)
3951                    .then_some(channel_id)
3952            }),
3953        |peer_id| {
3954            peer.send(
3955                peer_id,
3956                proto::UpdateChannels {
3957                    channel_participants: vec![proto::ChannelParticipants {
3958                        channel_id: channel.id.to_proto(),
3959                        participant_user_ids: participants.clone(),
3960                    }],
3961                    ..Default::default()
3962                },
3963            )
3964        },
3965    );
3966}
3967
3968async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3969    let db = session.db().await;
3970
3971    let contacts = db.get_contacts(user_id).await?;
3972    let busy = db.is_user_busy(user_id).await?;
3973
3974    let pool = session.connection_pool().await;
3975    let updated_contact = contact_for_user(user_id, busy, &pool);
3976    for contact in contacts {
3977        if let db::Contact::Accepted {
3978            user_id: contact_user_id,
3979            ..
3980        } = contact
3981        {
3982            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3983                session
3984                    .peer
3985                    .send(
3986                        contact_conn_id,
3987                        proto::UpdateContacts {
3988                            contacts: vec![updated_contact.clone()],
3989                            remove_contacts: Default::default(),
3990                            incoming_requests: Default::default(),
3991                            remove_incoming_requests: Default::default(),
3992                            outgoing_requests: Default::default(),
3993                            remove_outgoing_requests: Default::default(),
3994                        },
3995                    )
3996                    .trace_err();
3997            }
3998        }
3999    }
4000    Ok(())
4001}
4002
/// Remove `connection_id` from whatever room it is in, and notify everyone
/// affected.
///
/// If the database reports the connection was not in a room, this returns
/// `Ok(())` without doing anything else. Otherwise:
/// - every project the connection left is reported via `project_left`;
/// - the updated room is broadcast to its participants, and the channel's
///   participant list is updated when the room belongs to a channel;
/// - every user whose pending call into the room was canceled gets a
///   `CallCanceled` on each connection;
/// - contact entries are refreshed for the session's user and for each of those
///   users;
/// - when a LiveKit client is configured, the session's user is removed from
///   the LiveKit room, and the LiveKit room is deleted if the database reported
///   the room as deleted.
///
/// Send and LiveKit failures are logged and ignored. Database and contact-update
/// errors are returned.
///
/// NOTE(review): `project_left` reports `session.connection_id` as the removed
/// collaborator, and the LiveKit removal uses `session.user_id()`. Callers may
/// pass a `connection_id` that is not the session's own connection (for example
/// a stale one), so confirm that the session's connection is the intended peer
/// id in that case.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned exactly once in the `if let` branch below; the `else` branch
    // returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // The database guard from `session.db().await` is a temporary of the
    // `if let` scrutinee. It stays alive at least through the body of the
    // `if let`, so the projects and room notifications below run while it is
    // held.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // The LiveKit room name is taken out *before* the room itself is moved
        // out, so the `room` broadcast below carries an empty `livekit_room`.
        // NOTE(review): confirm this redaction is intentional.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before the contact
    // updates below, which take the pool themselves.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4076
4077async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4078    let left_channel_buffers = session
4079        .db()
4080        .await
4081        .leave_channel_buffers(session.connection_id)
4082        .await?;
4083
4084    for left_buffer in left_channel_buffers {
4085        channel_buffer_updated(
4086            session.connection_id,
4087            left_buffer.connections,
4088            &proto::UpdateChannelBufferCollaborators {
4089                channel_id: left_buffer.channel_id.to_proto(),
4090                collaborators: left_buffer.collaborators,
4091            },
4092            &session.peer,
4093        );
4094    }
4095
4096    Ok(())
4097}
4098
4099fn project_left(project: &db::LeftProject, session: &Session) {
4100    for connection_id in &project.connection_ids {
4101        if project.should_unshare {
4102            session
4103                .peer
4104                .send(
4105                    *connection_id,
4106                    proto::UnshareProject {
4107                        project_id: project.id.to_proto(),
4108                    },
4109                )
4110                .trace_err();
4111        } else {
4112            session
4113                .peer
4114                .send(
4115                    *connection_id,
4116                    proto::RemoveProjectCollaborator {
4117                        project_id: project.id.to_proto(),
4118                        peer_id: Some(session.connection_id.into()),
4119                    },
4120                )
4121                .trace_err();
4122        }
4123    }
4124}
4125
/// Extension for results whose error should be reported and then discarded.
pub trait ResultExt {
    /// The success type of the extended result.
    type Ok;

    /// Return the success value as `Some`, or report the error and return
    /// `None`. The `Result` implementation in this module reports errors
    /// through `tracing::error!`.
    fn trace_err(self) -> Option<Self::Ok>;
}
4131
4132impl<T, E> ResultExt for Result<T, E>
4133where
4134    E: std::fmt::Debug,
4135{
4136    type Ok = T;
4137
4138    #[track_caller]
4139    fn trace_err(self) -> Option<T> {
4140        match self {
4141            Ok(value) => Some(value),
4142            Err(error) => {
4143                tracing::error!("{:?}", error);
4144                None
4145            }
4146        }
4147    }
4148}