rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::{
   5    AppState, Error, Result, auth,
   6    db::{
   7        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
   8        InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
   9        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
  10    },
  11    executor::Executor,
  12};
  13use anyhow::{Context as _, anyhow, bail};
  14use async_tungstenite::tungstenite::{
  15    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  16};
  17use axum::headers::UserAgent;
  18use axum::{
  19    Extension, Router, TypedHeader,
  20    body::Body,
  21    extract::{
  22        ConnectInfo, WebSocketUpgrade,
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30};
  31use collections::{HashMap, HashSet};
  32pub use connection_pool::{ConnectionPool, ZedVersion};
  33use core::fmt::{self, Debug, Formatter};
  34use futures::TryFutureExt as _;
  35use reqwest_client::ReqwestClient;
  36use rpc::proto::split_repository_update;
  37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  38use tracing::Span;
  39use util::paths::PathStyle;
  40
  41use futures::{
  42    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  43    stream::FuturesUnordered,
  44};
  45use prometheus::{IntGauge, register_int_gauge};
  46use rpc::{
  47    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  50        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52};
  53use semantic_version::SemanticVersion;
  54use serde::{Serialize, Serializer};
  55use std::{
  56    any::TypeId,
  57    future::Future,
  58    marker::PhantomData,
  59    mem,
  60    net::SocketAddr,
  61    ops::{Deref, DerefMut},
  62    rc::Rc,
  63    sync::{
  64        Arc, OnceLock,
  65        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  66    },
  67    time::{Duration, Instant},
  68};
  69use tokio::sync::{Semaphore, watch};
  70use tower::ServiceBuilder;
  71use tracing::{
  72    Instrument,
  73    field::{self},
  74    info_span, instrument,
  75};
  76
  77pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  78
  79// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  80pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  81
  82const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  83const MAX_CONCURRENT_CONNECTIONS: usize = 512;
  84
  85static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
  86
  87const TOTAL_DURATION_MS: &str = "total_duration_ms";
  88const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
  89const QUEUE_DURATION_MS: &str = "queue_duration_ms";
  90const HOST_WAITING_MS: &str = "host_waiting_ms";
  91
  92type MessageHandler =
  93    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
  94
  95pub struct ConnectionGuard;
  96
  97impl ConnectionGuard {
  98    pub fn try_acquire() -> Result<Self, ()> {
  99        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
 100        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 101            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 102            tracing::error!(
 103                "too many concurrent connections: {}",
 104                current_connections + 1
 105            );
 106            return Err(());
 107        }
 108        Ok(ConnectionGuard)
 109    }
 110}
 111
 112impl Drop for ConnectionGuard {
 113    fn drop(&mut self) {
 114        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 115    }
 116}
 117
 118struct Response<R> {
 119    peer: Arc<Peer>,
 120    receipt: Receipt<R>,
 121    responded: Arc<AtomicBool>,
 122}
 123
 124impl<R: RequestMessage> Response<R> {
 125    fn send(self, payload: R::Response) -> Result<()> {
 126        self.responded.store(true, SeqCst);
 127        self.peer.respond(self.receipt, payload)?;
 128        Ok(())
 129    }
 130}
 131
 132#[derive(Clone, Debug)]
 133pub enum Principal {
 134    User(User),
 135    Impersonated { user: User, admin: User },
 136}
 137
 138impl Principal {
 139    fn update_span(&self, span: &tracing::Span) {
 140        match &self {
 141            Principal::User(user) => {
 142                span.record("user_id", user.id.0);
 143                span.record("login", &user.github_login);
 144            }
 145            Principal::Impersonated { user, admin } => {
 146                span.record("user_id", user.id.0);
 147                span.record("login", &user.github_login);
 148                span.record("impersonator", &admin.github_login);
 149            }
 150        }
 151    }
 152}
 153
 154#[derive(Clone)]
 155struct MessageContext {
 156    session: Session,
 157    span: tracing::Span,
 158}
 159
 160impl Deref for MessageContext {
 161    type Target = Session;
 162
 163    fn deref(&self) -> &Self::Target {
 164        &self.session
 165    }
 166}
 167
 168impl MessageContext {
 169    pub fn forward_request<T: RequestMessage>(
 170        &self,
 171        receiver_id: ConnectionId,
 172        request: T,
 173    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 174        let request_start_time = Instant::now();
 175        let span = self.span.clone();
 176        tracing::info!("start forwarding request");
 177        self.peer
 178            .forward_request(self.connection_id, receiver_id, request)
 179            .inspect(move |_| {
 180                span.record(
 181                    HOST_WAITING_MS,
 182                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 183                );
 184            })
 185            .inspect_err(|_| tracing::error!("error forwarding request"))
 186            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 187    }
 188}
 189
 190#[derive(Clone)]
 191struct Session {
 192    principal: Principal,
 193    connection_id: ConnectionId,
 194    db: Arc<tokio::sync::Mutex<DbHandle>>,
 195    peer: Arc<Peer>,
 196    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 197    app_state: Arc<AppState>,
 198    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 199    /// The GeoIP country code for the user.
 200    #[allow(unused)]
 201    geoip_country_code: Option<String>,
 202    #[allow(unused)]
 203    system_id: Option<String>,
 204    _executor: Executor,
 205}
 206
 207impl Session {
 208    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 209        #[cfg(test)]
 210        tokio::task::yield_now().await;
 211        let guard = self.db.lock().await;
 212        #[cfg(test)]
 213        tokio::task::yield_now().await;
 214        guard
 215    }
 216
 217    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 218        #[cfg(test)]
 219        tokio::task::yield_now().await;
 220        let guard = self.connection_pool.lock();
 221        ConnectionPoolGuard {
 222            guard,
 223            _not_send: PhantomData,
 224        }
 225    }
 226
 227    fn is_staff(&self) -> bool {
 228        match &self.principal {
 229            Principal::User(user) => user.admin,
 230            Principal::Impersonated { .. } => true,
 231        }
 232    }
 233
 234    fn user_id(&self) -> UserId {
 235        match &self.principal {
 236            Principal::User(user) => user.id,
 237            Principal::Impersonated { user, .. } => user.id,
 238        }
 239    }
 240
 241    pub fn email(&self) -> Option<String> {
 242        match &self.principal {
 243            Principal::User(user) => user.email_address.clone(),
 244            Principal::Impersonated { user, .. } => user.email_address.clone(),
 245        }
 246    }
 247}
 248
 249impl Debug for Session {
 250    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 251        let mut result = f.debug_struct("Session");
 252        match &self.principal {
 253            Principal::User(user) => {
 254                result.field("user", &user.github_login);
 255            }
 256            Principal::Impersonated { user, admin } => {
 257                result.field("user", &user.github_login);
 258                result.field("impersonator", &admin.github_login);
 259            }
 260        }
 261        result.field("connection_id", &self.connection_id).finish()
 262    }
 263}
 264
 265struct DbHandle(Arc<Database>);
 266
 267impl Deref for DbHandle {
 268    type Target = Database;
 269
 270    fn deref(&self) -> &Self::Target {
 271        self.0.as_ref()
 272    }
 273}
 274
 275pub struct Server {
 276    id: parking_lot::Mutex<ServerId>,
 277    peer: Arc<Peer>,
 278    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 279    app_state: Arc<AppState>,
 280    handlers: HashMap<TypeId, MessageHandler>,
 281    teardown: watch::Sender<bool>,
 282}
 283
 284pub(crate) struct ConnectionPoolGuard<'a> {
 285    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 286    _not_send: PhantomData<Rc<()>>,
 287}
 288
 289#[derive(Serialize)]
 290pub struct ServerSnapshot<'a> {
 291    peer: &'a Peer,
 292    #[serde(serialize_with = "serialize_deref")]
 293    connection_pool: ConnectionPoolGuard<'a>,
 294}
 295
 296pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 297where
 298    S: Serializer,
 299    T: Deref<Target = U>,
 300    U: Serialize,
 301{
 302    Serialize::serialize(value.deref(), serializer)
 303}
 304
 305impl Server {
 306    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 307        let mut server = Self {
 308            id: parking_lot::Mutex::new(id),
 309            peer: Peer::new(id.0 as u32),
 310            app_state,
 311            connection_pool: Default::default(),
 312            handlers: Default::default(),
 313            teardown: watch::channel(false).0,
 314        };
 315
 316        server
 317            .add_request_handler(ping)
 318            .add_request_handler(create_room)
 319            .add_request_handler(join_room)
 320            .add_request_handler(rejoin_room)
 321            .add_request_handler(leave_room)
 322            .add_request_handler(set_room_participant_role)
 323            .add_request_handler(call)
 324            .add_request_handler(cancel_call)
 325            .add_message_handler(decline_call)
 326            .add_request_handler(update_participant_location)
 327            .add_request_handler(share_project)
 328            .add_message_handler(unshare_project)
 329            .add_request_handler(join_project)
 330            .add_message_handler(leave_project)
 331            .add_request_handler(update_project)
 332            .add_request_handler(update_worktree)
 333            .add_request_handler(update_repository)
 334            .add_request_handler(remove_repository)
 335            .add_message_handler(start_language_server)
 336            .add_message_handler(update_language_server)
 337            .add_message_handler(update_diagnostic_summary)
 338            .add_message_handler(update_worktree_settings)
 339            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 340            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 341            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 342            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 343            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 344            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 345            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 346            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 347            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 348            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 349            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
 350            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 351            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
 352            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 353            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 354            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 355            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 356            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 357            .add_request_handler(
 358                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 359            )
 360            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 361            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 362            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 363            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 364            .add_request_handler(
 365                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 366            )
 367            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 368            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 369            .add_request_handler(
 370                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 371            )
 372            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 373            .add_request_handler(
 374                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 375            )
 376            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 377            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 378            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 379            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 380            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 381            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 382            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 383            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 384            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 385            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 386            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 387            .add_request_handler(
 388                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 389            )
 390            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 391            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 392            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 393            .add_request_handler(lsp_query)
 394            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
 395            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 396            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 397            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 398            .add_message_handler(create_buffer_for_peer)
 399            .add_message_handler(create_image_for_peer)
 400            .add_request_handler(update_buffer)
 401            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 402            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 403            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 404            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 405            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 406            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 407            .add_message_handler(
 408                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 409            )
 410            .add_request_handler(get_users)
 411            .add_request_handler(fuzzy_search_users)
 412            .add_request_handler(request_contact)
 413            .add_request_handler(remove_contact)
 414            .add_request_handler(respond_to_contact_request)
 415            .add_message_handler(subscribe_to_channels)
 416            .add_request_handler(create_channel)
 417            .add_request_handler(delete_channel)
 418            .add_request_handler(invite_channel_member)
 419            .add_request_handler(remove_channel_member)
 420            .add_request_handler(set_channel_member_role)
 421            .add_request_handler(set_channel_visibility)
 422            .add_request_handler(rename_channel)
 423            .add_request_handler(join_channel_buffer)
 424            .add_request_handler(leave_channel_buffer)
 425            .add_message_handler(update_channel_buffer)
 426            .add_request_handler(rejoin_channel_buffers)
 427            .add_request_handler(get_channel_members)
 428            .add_request_handler(respond_to_channel_invite)
 429            .add_request_handler(join_channel)
 430            .add_request_handler(join_channel_chat)
 431            .add_message_handler(leave_channel_chat)
 432            .add_request_handler(send_channel_message)
 433            .add_request_handler(remove_channel_message)
 434            .add_request_handler(update_channel_message)
 435            .add_request_handler(get_channel_messages)
 436            .add_request_handler(get_channel_messages_by_id)
 437            .add_request_handler(get_notifications)
 438            .add_request_handler(mark_notification_as_read)
 439            .add_request_handler(move_channel)
 440            .add_request_handler(reorder_channel)
 441            .add_request_handler(follow)
 442            .add_message_handler(unfollow)
 443            .add_message_handler(update_followers)
 444            .add_message_handler(acknowledge_channel_message)
 445            .add_message_handler(acknowledge_buffer_version)
 446            .add_request_handler(get_supermaven_api_key)
 447            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 448            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 449            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 450            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 451            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 452            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 453            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 454            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
 455            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 456            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 457            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 458            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 459            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 460            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 461            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 462            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 463            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 464            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 465            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 466            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 467            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
 468            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
 469            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 470            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 471            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 472            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 473            .add_message_handler(update_context)
 474            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
 475            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>);
 476
 477        Arc::new(server)
 478    }
 479
    /// Schedules cleanup of state that the database still attributes to stale
    /// servers (selected by the `stale_*` / `delete_stale_*` database calls
    /// below; their exact criteria live in the db layer).
    ///
    /// The cleanup runs on a detached task that first waits out
    /// `CLEANUP_TIMEOUT`, so this method itself returns `Ok(())` right away.
    /// Every database and send failure inside the task is passed through
    /// `trace_err()` instead of being propagated.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Clear stale collaborators from each channel buffer, then
                    // push the refreshed collaborator list to every connection
                    // the database returned for that buffer.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Populated only when clearing the room's stale
                        // participants succeeds; otherwise every step below is
                        // a no-op for this room.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and users whose calls
                            // were canceled need fresh contact entries.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // Only an empty room is a candidate for LiveKit deletion.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the
                        // `.await`s in the contact loop below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Recompute each affected user's contact entry and send
                        // it to every connection of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // The pool is locked only after both awaits, so the
                            // lock is never held across an `.await`.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room only when a LiveKit client is
                        // configured and the refreshed room has no participants.
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                // NOTE(review): this repeats the
                // `delete_stale_channel_chat_participants` call made right after
                // the timeout; confirm whether the second pass is intentional or
                // can be removed.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                // Last step: remove the stale server records themselves.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 650
 651    pub fn teardown(&self) {
 652        self.peer.teardown();
 653        self.connection_pool.lock().reset();
 654        let _ = self.teardown.send(true);
 655    }
 656
 657    #[cfg(test)]
 658    pub fn reset(&self, id: ServerId) {
 659        self.teardown();
 660        *self.id.lock() = id;
 661        self.peer.reset(id.0 as u32);
 662        let _ = self.teardown.send(false);
 663    }
 664
 665    #[cfg(test)]
 666    pub fn id(&self) -> ServerId {
 667        *self.id.lock()
 668    }
 669
    /// Registers a type-erased handler for envelopes carrying message type `M`.
    ///
    /// The stored closure does four things:
    /// - downcasts the incoming envelope to `TypedEnvelope<M>`;
    /// - invokes `handler` with a `MessageContext` built from the session and span;
    /// - once the handler future completes, records three durations in milliseconds on the
    ///   span: total (since `received_at`), processing (since the handler was invoked) and
    ///   queue (their difference);
    /// - logs the outcome. Handler errors are logged here and not propagated further.
    ///
    /// # Panics
    ///
    /// Panics if a handler for `M` has already been registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // The map is keyed by `TypeId::of::<M>()`, so the downcast is expected to
                // succeed for any envelope routed to this entry.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                // Taken before the handler is invoked; the gap from `received_at` is queue time.
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
 713
 714    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 715    where
 716        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 717        Fut: 'static + Send + Future<Output = Result<()>>,
 718        M: EnvelopedMessage,
 719    {
 720        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 721        self
 722    }
 723
    /// Registers a handler for request messages of type `M`, which must answer through the
    /// supplied `Response<M>`.
    ///
    /// Outcomes:
    /// - The handler returns `Ok` and has sent a response (tracked by the shared `responded`
    ///   flag): success.
    /// - The handler returns `Ok` without sending a response: treated as an error
    ///   ("handler did not send a response").
    /// - The handler returns `Err`: the error is converted to a proto error and sent back to
    ///   the requester, then returned. The wrapper in `add_handler` logs it.
    ///
    /// NOTE(review): in the "did not send a response" case nothing is sent to the requester
    /// from here, so the client presumably waits until its own timeout — confirm.
    ///
    /// # Panics
    ///
    /// Panics (via `add_handler`) if `M` already has a handler.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // Shared with the `Response` so we can tell afterwards whether it was used.
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // Internal errors already carry a proto-encodable error; anything else
                        // is reported to the client as `ErrorCode::Internal` with its message.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
 762
    /// Builds the future that drives one client connection from handshake to sign-out.
    ///
    /// The returned future:
    /// 1. returns immediately if teardown has already been signalled;
    /// 2. registers the connection with the peer and builds the per-connection `Session`,
    ///    including a Supermaven admin client when an admin API key is configured;
    /// 3. performs `send_initial_client_update`, then releases `connection_guard`;
    /// 4. dispatches incoming messages to the registered handlers:
    ///    - at most `MAX_CONCURRENT_HANDLERS` run at once;
    ///    - background messages are spawned detached;
    ///    - foreground ones are polled from this loop;
    /// 5. on I/O error or closed connection, calls `connection_lost`. On teardown it returns
    ///    without doing so.
    ///
    /// User agent, release channel, GeoIP country code and the principal are recorded on the
    /// "handle connection" span when present.
    ///
    /// NOTE(review): if `send_initial_client_update` fails (after `peer.add_connection`, and
    /// possibly after the connection was added to the pool), this returns early without
    /// calling `connection_lost`, so that cleanup appears to be skipped — confirm.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribed before the future starts so a teardown signalled afterwards is observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // Note: this shadows the client's `user_agent` parameter with the server's own
            // user agent, used for outgoing HTTP requests.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The handshake is complete; release the concurrent-connection slot.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // A permit is acquired *before* reading the next message, so once every permit
                // is held by a running handler, no further messages are pulled off the channel.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in order: teardown, I/O, in-flight handlers, new messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // The permit is released only when the handler finishes, so it
                                // counts against MAX_CONCURRENT_HANDLERS for its whole lifetime.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                // `permit` is dropped at the end of this arm.
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that are still pending.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
 954
 955    async fn send_initial_client_update(
 956        &self,
 957        connection_id: ConnectionId,
 958        zed_version: ZedVersion,
 959        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 960        session: &Session,
 961    ) -> Result<()> {
 962        self.peer.send(
 963            connection_id,
 964            proto::Hello {
 965                peer_id: Some(connection_id.into()),
 966            },
 967        )?;
 968        tracing::info!("sent hello message");
 969        if let Some(send_connection_id) = send_connection_id.take() {
 970            let _ = send_connection_id.send(connection_id);
 971        }
 972
 973        match &session.principal {
 974            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 975                if !user.connected_once {
 976                    self.peer.send(connection_id, proto::ShowContacts {})?;
 977                    self.app_state
 978                        .db
 979                        .set_user_connected_once(user.id, true)
 980                        .await?;
 981                }
 982
 983                let contacts = self.app_state.db.get_contacts(user.id).await?;
 984
 985                {
 986                    let mut pool = self.connection_pool.lock();
 987                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 988                    self.peer.send(
 989                        connection_id,
 990                        build_initial_contacts_update(contacts, &pool),
 991                    )?;
 992                }
 993
 994                if should_auto_subscribe_to_channels(zed_version) {
 995                    subscribe_user_to_channels(user.id, session).await?;
 996                }
 997
 998                if let Some(incoming_call) =
 999                    self.app_state.db.incoming_call_for_user(user.id).await?
1000                {
1001                    self.peer.send(connection_id, incoming_call)?;
1002                }
1003
1004                update_user_contacts(user.id, session).await?;
1005            }
1006        }
1007
1008        Ok(())
1009    }
1010
1011    pub async fn invite_code_redeemed(
1012        self: &Arc<Self>,
1013        inviter_id: UserId,
1014        invitee_id: UserId,
1015    ) -> Result<()> {
1016        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await?
1017            && let Some(code) = &user.invite_code
1018        {
1019            let pool = self.connection_pool.lock();
1020            let invitee_contact = contact_for_user(invitee_id, false, &pool);
1021            for connection_id in pool.user_connection_ids(inviter_id) {
1022                self.peer.send(
1023                    connection_id,
1024                    proto::UpdateContacts {
1025                        contacts: vec![invitee_contact.clone()],
1026                        ..Default::default()
1027                    },
1028                )?;
1029                self.peer.send(
1030                    connection_id,
1031                    proto::UpdateInviteInfo {
1032                        url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1033                        count: user.invite_count as u32,
1034                    },
1035                )?;
1036            }
1037        }
1038        Ok(())
1039    }
1040
1041    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1042        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await?
1043            && let Some(invite_code) = &user.invite_code
1044        {
1045            let pool = self.connection_pool.lock();
1046            for connection_id in pool.user_connection_ids(user_id) {
1047                self.peer.send(
1048                    connection_id,
1049                    proto::UpdateInviteInfo {
1050                        url: format!(
1051                            "{}{}",
1052                            self.app_state.config.invite_link_prefix, invite_code
1053                        ),
1054                        count: user.invite_count as u32,
1055                    },
1056                )?;
1057            }
1058        }
1059        Ok(())
1060    }
1061
1062    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1063        ServerSnapshot {
1064            connection_pool: ConnectionPoolGuard {
1065                guard: self.connection_pool.lock(),
1066                _not_send: PhantomData,
1067            },
1068            peer: &self.peer,
1069        }
1070    }
1071}
1072
1073impl Deref for ConnectionPoolGuard<'_> {
1074    type Target = ConnectionPool;
1075
1076    fn deref(&self) -> &Self::Target {
1077        &self.guard
1078    }
1079}
1080
1081impl DerefMut for ConnectionPoolGuard<'_> {
1082    fn deref_mut(&mut self) -> &mut Self::Target {
1083        &mut self.guard
1084    }
1085}
1086
/// In test builds, validates the pool's invariants every time a `ConnectionPoolGuard` is
/// dropped. This runs before the inner lock guard field is dropped, so the pool is still
/// locked during the check.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // Compiled out entirely outside of tests.
        #[cfg(test)]
        self.check_invariants();
    }
}
1093
1094fn broadcast<F>(
1095    sender_id: Option<ConnectionId>,
1096    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1097    mut f: F,
1098) where
1099    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1100{
1101    for receiver_id in receiver_ids {
1102        if Some(receiver_id) != sender_id
1103            && let Err(error) = f(receiver_id)
1104        {
1105            tracing::error!("failed to send to {:?} {}", receiver_id, error);
1106        }
1107    }
1108}
1109
1110pub struct ProtocolVersion(u32);
1111
1112impl Header for ProtocolVersion {
1113    fn name() -> &'static HeaderName {
1114        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1115        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1116    }
1117
1118    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1119    where
1120        Self: Sized,
1121        I: Iterator<Item = &'i axum::http::HeaderValue>,
1122    {
1123        let version = values
1124            .next()
1125            .ok_or_else(axum::headers::Error::invalid)?
1126            .to_str()
1127            .map_err(|_| axum::headers::Error::invalid())?
1128            .parse()
1129            .map_err(|_| axum::headers::Error::invalid())?;
1130        Ok(Self(version))
1131    }
1132
1133    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1134        values.extend([self.0.to_string().parse().unwrap()]);
1135    }
1136}
1137
1138pub struct AppVersionHeader(SemanticVersion);
1139impl Header for AppVersionHeader {
1140    fn name() -> &'static HeaderName {
1141        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1142        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1143    }
1144
1145    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1146    where
1147        Self: Sized,
1148        I: Iterator<Item = &'i axum::http::HeaderValue>,
1149    {
1150        let version = values
1151            .next()
1152            .ok_or_else(axum::headers::Error::invalid)?
1153            .to_str()
1154            .map_err(|_| axum::headers::Error::invalid())?
1155            .parse()
1156            .map_err(|_| axum::headers::Error::invalid())?;
1157        Ok(Self(version))
1158    }
1159
1160    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1161        values.extend([self.0.to_string().parse().unwrap()]);
1162    }
1163}
1164
1165#[derive(Debug)]
1166pub struct ReleaseChannelHeader(String);
1167
1168impl Header for ReleaseChannelHeader {
1169    fn name() -> &'static HeaderName {
1170        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1171        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1172    }
1173
1174    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1175    where
1176        Self: Sized,
1177        I: Iterator<Item = &'i axum::http::HeaderValue>,
1178    {
1179        Ok(Self(
1180            values
1181                .next()
1182                .ok_or_else(axum::headers::Error::invalid)?
1183                .to_str()
1184                .map_err(|_| axum::headers::Error::invalid())?
1185                .to_owned(),
1186        ))
1187    }
1188
1189    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1190        values.extend([self.0.parse().unwrap()]);
1191    }
1192}
1193
1194pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1195    Router::new()
1196        .route("/rpc", get(handle_websocket_request))
1197        .layer(
1198            ServiceBuilder::new()
1199                .layer(Extension(server.app_state.clone()))
1200                .layer(middleware::from_fn(auth::validate_header)),
1201        )
1202        .route("/metrics", get(handle_metrics))
1203        .layer(Extension(server))
1204}
1205
1206pub async fn handle_websocket_request(
1207    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1208    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1209    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1210    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1211    Extension(server): Extension<Arc<Server>>,
1212    Extension(principal): Extension<Principal>,
1213    user_agent: Option<TypedHeader<UserAgent>>,
1214    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1215    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1216    ws: WebSocketUpgrade,
1217) -> axum::response::Response {
1218    if protocol_version != rpc::PROTOCOL_VERSION {
1219        return (
1220            StatusCode::UPGRADE_REQUIRED,
1221            "client must be upgraded".to_string(),
1222        )
1223            .into_response();
1224    }
1225
1226    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1227        return (
1228            StatusCode::UPGRADE_REQUIRED,
1229            "no version header found".to_string(),
1230        )
1231            .into_response();
1232    };
1233
1234    let release_channel = release_channel_header.map(|header| header.0.0);
1235
1236    if !version.can_collaborate() {
1237        return (
1238            StatusCode::UPGRADE_REQUIRED,
1239            "client must be upgraded".to_string(),
1240        )
1241            .into_response();
1242    }
1243
1244    let socket_address = socket_address.to_string();
1245
1246    // Acquire connection guard before WebSocket upgrade
1247    let connection_guard = match ConnectionGuard::try_acquire() {
1248        Ok(guard) => guard,
1249        Err(()) => {
1250            return (
1251                StatusCode::SERVICE_UNAVAILABLE,
1252                "Too many concurrent connections",
1253            )
1254                .into_response();
1255        }
1256    };
1257
1258    ws.on_upgrade(move |socket| {
1259        let socket = socket
1260            .map_ok(to_tungstenite_message)
1261            .err_into()
1262            .with(|message| async move { to_axum_message(message) });
1263        let connection = Connection::new(Box::pin(socket));
1264        async move {
1265            server
1266                .handle_connection(
1267                    connection,
1268                    socket_address,
1269                    principal,
1270                    version,
1271                    release_channel,
1272                    user_agent.map(|header| header.to_string()),
1273                    country_code_header.map(|header| header.to_string()),
1274                    system_id_header.map(|header| header.to_string()),
1275                    None,
1276                    Executor::Production,
1277                    Some(connection_guard),
1278                )
1279                .await;
1280        }
1281    })
1282}
1283
1284pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1285    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1286    let connections_metric = CONNECTIONS_METRIC
1287        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1288
1289    let connections = server
1290        .connection_pool
1291        .lock()
1292        .connections()
1293        .filter(|connection| !connection.admin)
1294        .count();
1295    connections_metric.set(connections as _);
1296
1297    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1298    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1299        register_int_gauge!(
1300            "shared_projects",
1301            "number of open projects with one or more guests"
1302        )
1303        .unwrap()
1304    });
1305
1306    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1307    shared_projects_metric.set(shared_projects as _);
1308
1309    let encoder = prometheus::TextEncoder::new();
1310    let metric_families = prometheus::gather();
1311    let encoded_metrics = encoder
1312        .encode_to_string(&metric_families)
1313        .map_err(|err| anyhow!("{err}"))?;
1314    Ok(encoded_metrics)
1315}
1316
/// Cleans up after a connection has gone away.
///
/// Done immediately:
/// - disconnect from the peer;
/// - remove from the connection pool;
/// - mark the connection lost in the database.
///
/// Room and channel-buffer membership is kept for `RECONNECT_TIMEOUT`, presumably so the
/// client can reconnect and resume. If that timeout elapses first:
/// - the session leaves its room and channel buffers;
/// - if the user has no other online connection, any pending call is declined;
/// - contacts are refreshed.
///
/// If teardown is signalled first, that delayed cleanup is skipped.
///
/// # Errors
///
/// Returns an error if removing the connection from the pool or updating contacts fails.
/// Other failures are logged via `trace_err` and otherwise ignored.
///
/// NOTE(review): an error from `remove_connection` returns early, skipping the database
/// `connection_lost` call — confirm that is intended.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased: if both are ready, the reconnect timeout branch wins.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // The pool guard in this condition is a temporary, released before the block runs.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1363
1364/// Acknowledges a ping from a client, used to keep the connection alive.
1365async fn ping(
1366    _: proto::Ping,
1367    response: Response<proto::Ping>,
1368    _session: MessageContext,
1369) -> Result<()> {
1370    response.send(proto::Ack {})?;
1371    Ok(())
1372}
1373
1374/// Creates a new room for calling (outside of channels)
1375async fn create_room(
1376    _request: proto::CreateRoom,
1377    response: Response<proto::CreateRoom>,
1378    session: MessageContext,
1379) -> Result<()> {
1380    let livekit_room = nanoid::nanoid!(30);
1381
1382    let live_kit_connection_info = util::maybe!(async {
1383        let live_kit = session.app_state.livekit_client.as_ref();
1384        let live_kit = live_kit?;
1385        let user_id = session.user_id().to_string();
1386
1387        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1388
1389        Some(proto::LiveKitConnectionInfo {
1390            server_url: live_kit.url().into(),
1391            token,
1392            can_publish: true,
1393        })
1394    })
1395    .await;
1396
1397    let room = session
1398        .db()
1399        .await
1400        .create_room(session.user_id(), session.connection_id, &livekit_room)
1401        .await?;
1402
1403    response.send(proto::CreateRoomResponse {
1404        room: Some(room.clone()),
1405        live_kit_connection_info,
1406    })?;
1407
1408    update_user_contacts(session.user_id(), &session).await?;
1409    Ok(())
1410}
1411
1412/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1413async fn join_room(
1414    request: proto::JoinRoom,
1415    response: Response<proto::JoinRoom>,
1416    session: MessageContext,
1417) -> Result<()> {
1418    let room_id = RoomId::from_proto(request.id);
1419
1420    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1421
1422    if let Some(channel_id) = channel_id {
1423        return join_channel_internal(channel_id, Box::new(response), session).await;
1424    }
1425
1426    let joined_room = {
1427        let room = session
1428            .db()
1429            .await
1430            .join_room(room_id, session.user_id(), session.connection_id)
1431            .await?;
1432        room_updated(&room.room, &session.peer);
1433        room.into_inner()
1434    };
1435
1436    for connection_id in session
1437        .connection_pool()
1438        .await
1439        .user_connection_ids(session.user_id())
1440    {
1441        session
1442            .peer
1443            .send(
1444                connection_id,
1445                proto::CallCanceled {
1446                    room_id: room_id.to_proto(),
1447                },
1448            )
1449            .trace_err();
1450    }
1451
1452    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1453    {
1454        live_kit
1455            .room_token(
1456                &joined_room.room.livekit_room,
1457                &session.user_id().to_string(),
1458            )
1459            .trace_err()
1460            .map(|token| proto::LiveKitConnectionInfo {
1461                server_url: live_kit.url().into(),
1462                token,
1463                can_publish: true,
1464            })
1465    } else {
1466        None
1467    };
1468
1469    response.send(proto::JoinRoomResponse {
1470        room: Some(joined_room.room),
1471        channel_id: None,
1472        live_kit_connection_info,
1473    })?;
1474
1475    update_user_contacts(session.user_id(), &session).await?;
1476    Ok(())
1477}
1478
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database call (`rejoin_room`) returns the room together with `reshared_projects` and
/// `rejoined_projects`. This handler then, in order:
/// 1. replies with the room and both project lists,
/// 2. broadcasts the updated room,
/// 3. for every reshared project, tells each collaborator that this user's peer id changed from
///    the project's `old_connection_id` to the current connection, and forwards them an
///    `UpdateProject` carrying the project's worktrees (presumably these are projects this user
///    hosts — TODO confirm against `Database::rejoin_room`),
/// 4. catches this connection up on the rejoined projects via `notify_rejoined_projects`,
/// 5. broadcasts the channel, if the room belongs to one, and
/// 6. refreshes this user's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    // Scope the `rejoined_room` guard so it is dropped before the connection pool is locked
    // and before `update_user_contacts` is awaited below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Point every collaborator at this connection's new peer id. Send failures are
            // only logged.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktree list, forwarded on behalf of this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Take ownership of the guarded value so `room` and `channel` can outlive this scope.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1570
1571fn notify_rejoined_projects(
1572    rejoined_projects: &mut Vec<RejoinedProject>,
1573    session: &Session,
1574) -> Result<()> {
1575    for project in rejoined_projects.iter() {
1576        for collaborator in &project.collaborators {
1577            session
1578                .peer
1579                .send(
1580                    collaborator.connection_id,
1581                    proto::UpdateProjectCollaborator {
1582                        project_id: project.id.to_proto(),
1583                        old_peer_id: Some(project.old_connection_id.into()),
1584                        new_peer_id: Some(session.connection_id.into()),
1585                    },
1586                )
1587                .trace_err();
1588        }
1589    }
1590
1591    for project in rejoined_projects {
1592        for worktree in mem::take(&mut project.worktrees) {
1593            // Stream this worktree's entries.
1594            let message = proto::UpdateWorktree {
1595                project_id: project.id.to_proto(),
1596                worktree_id: worktree.id,
1597                abs_path: worktree.abs_path.clone(),
1598                root_name: worktree.root_name,
1599                updated_entries: worktree.updated_entries,
1600                removed_entries: worktree.removed_entries,
1601                scan_id: worktree.scan_id,
1602                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1603                updated_repositories: worktree.updated_repositories,
1604                removed_repositories: worktree.removed_repositories,
1605            };
1606            for update in proto::split_worktree_update(message) {
1607                session.peer.send(session.connection_id, update)?;
1608            }
1609
1610            // Stream this worktree's diagnostics.
1611            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1612            if let Some(summary) = worktree_diagnostics.next() {
1613                let message = proto::UpdateDiagnosticSummary {
1614                    project_id: project.id.to_proto(),
1615                    worktree_id: worktree.id,
1616                    summary: Some(summary),
1617                    more_summaries: worktree_diagnostics.collect(),
1618                };
1619                session.peer.send(session.connection_id, message)?;
1620            }
1621
1622            for settings_file in worktree.settings_files {
1623                session.peer.send(
1624                    session.connection_id,
1625                    proto::UpdateWorktreeSettings {
1626                        project_id: project.id.to_proto(),
1627                        worktree_id: worktree.id,
1628                        path: settings_file.path,
1629                        content: Some(settings_file.content),
1630                        kind: Some(settings_file.kind.to_proto().into()),
1631                    },
1632                )?;
1633            }
1634        }
1635
1636        for repository in mem::take(&mut project.updated_repositories) {
1637            for update in split_repository_update(repository) {
1638                session.peer.send(session.connection_id, update)?;
1639            }
1640        }
1641
1642        for id in mem::take(&mut project.removed_repositories) {
1643            session.peer.send(
1644                session.connection_id,
1645                proto::RemoveRepository {
1646                    project_id: project.id.to_proto(),
1647                    id,
1648                },
1649            )?;
1650        }
1651    }
1652
1653    Ok(())
1654}
1655
1656/// leave room disconnects from the room.
1657async fn leave_room(
1658    _: proto::LeaveRoom,
1659    response: Response<proto::LeaveRoom>,
1660    session: MessageContext,
1661) -> Result<()> {
1662    leave_room_for_session(&session, session.connection_id).await?;
1663    response.send(proto::Ack {})?;
1664    Ok(())
1665}
1666
1667/// Updates the permissions of someone else in the room.
1668async fn set_room_participant_role(
1669    request: proto::SetRoomParticipantRole,
1670    response: Response<proto::SetRoomParticipantRole>,
1671    session: MessageContext,
1672) -> Result<()> {
1673    let user_id = UserId::from_proto(request.user_id);
1674    let role = ChannelRole::from(request.role());
1675
1676    let (livekit_room, can_publish) = {
1677        let room = session
1678            .db()
1679            .await
1680            .set_room_participant_role(
1681                session.user_id(),
1682                RoomId::from_proto(request.room_id),
1683                user_id,
1684                role,
1685            )
1686            .await?;
1687
1688        let livekit_room = room.livekit_room.clone();
1689        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1690        room_updated(&room, &session.peer);
1691        (livekit_room, can_publish)
1692    };
1693
1694    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1695        live_kit
1696            .update_participant(
1697                livekit_room.clone(),
1698                request.user_id.to_string(),
1699                livekit_api::proto::ParticipantPermission {
1700                    can_subscribe: true,
1701                    can_publish,
1702                    can_publish_data: can_publish,
1703                    hidden: false,
1704                    recorder: false,
1705                },
1706            )
1707            .await
1708            .trace_err();
1709    }
1710
1711    response.send(proto::Ack {})?;
1712    Ok(())
1713}
1714
/// Call someone else into the current room
///
/// Only users who are contacts of the caller may be called. The call is recorded in the
/// database and the updated room is broadcast. The called user's contacts are refreshed, and
/// every connection of the called user is then sent the incoming-call request concurrently.
///
/// The first connection whose request succeeds completes the call, and the caller receives an
/// `Ack`. If no request succeeds (including when the called user has no connections at all),
/// the following happens:
/// - the call is marked as failed in the database;
/// - the room is broadcast again;
/// - the called user's contacts are refreshed;
/// - the caller gets a "failed to ring user" error.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Record the call, broadcast the room, and move the incoming-call message out of the guard
    // so the guard can be released at the end of this block.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring all of the called user's connections at once; responses are handled in completion
    // order.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            // The first successful ring is enough to consider the call delivered.
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            // Log the failure and keep waiting for the remaining connections.
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection accepted the ring: undo the pending call and let everyone know.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1783
1784/// Cancel an outgoing call.
1785async fn cancel_call(
1786    request: proto::CancelCall,
1787    response: Response<proto::CancelCall>,
1788    session: MessageContext,
1789) -> Result<()> {
1790    let called_user_id = UserId::from_proto(request.called_user_id);
1791    let room_id = RoomId::from_proto(request.room_id);
1792    {
1793        let room = session
1794            .db()
1795            .await
1796            .cancel_call(room_id, session.connection_id, called_user_id)
1797            .await?;
1798        room_updated(&room, &session.peer);
1799    }
1800
1801    for connection_id in session
1802        .connection_pool()
1803        .await
1804        .user_connection_ids(called_user_id)
1805    {
1806        session
1807            .peer
1808            .send(
1809                connection_id,
1810                proto::CallCanceled {
1811                    room_id: room_id.to_proto(),
1812                },
1813            )
1814            .trace_err();
1815    }
1816    response.send(proto::Ack {})?;
1817
1818    update_user_contacts(called_user_id, &session).await?;
1819    Ok(())
1820}
1821
1822/// Decline an incoming call.
1823async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1824    let room_id = RoomId::from_proto(message.room_id);
1825    {
1826        let room = session
1827            .db()
1828            .await
1829            .decline_call(Some(room_id), session.user_id())
1830            .await?
1831            .context("declining call")?;
1832        room_updated(&room, &session.peer);
1833    }
1834
1835    for connection_id in session
1836        .connection_pool()
1837        .await
1838        .user_connection_ids(session.user_id())
1839    {
1840        session
1841            .peer
1842            .send(
1843                connection_id,
1844                proto::CallCanceled {
1845                    room_id: room_id.to_proto(),
1846                },
1847            )
1848            .trace_err();
1849    }
1850    update_user_contacts(session.user_id(), &session).await?;
1851    Ok(())
1852}
1853
1854/// Updates other participants in the room with your current location.
1855async fn update_participant_location(
1856    request: proto::UpdateParticipantLocation,
1857    response: Response<proto::UpdateParticipantLocation>,
1858    session: MessageContext,
1859) -> Result<()> {
1860    let room_id = RoomId::from_proto(request.room_id);
1861    let location = request.location.context("invalid location")?;
1862
1863    let db = session.db().await;
1864    let room = db
1865        .update_room_participant_location(room_id, session.connection_id, location)
1866        .await?;
1867
1868    room_updated(&room, &session.peer);
1869    response.send(proto::Ack {})?;
1870    Ok(())
1871}
1872
1873/// Share a project into the room.
1874async fn share_project(
1875    request: proto::ShareProject,
1876    response: Response<proto::ShareProject>,
1877    session: MessageContext,
1878) -> Result<()> {
1879    let (project_id, room) = &*session
1880        .db()
1881        .await
1882        .share_project(
1883            RoomId::from_proto(request.room_id),
1884            session.connection_id,
1885            &request.worktrees,
1886            request.is_ssh_project,
1887            request.windows_paths.unwrap_or(false),
1888        )
1889        .await?;
1890    response.send(proto::ShareProjectResponse {
1891        project_id: project_id.to_proto(),
1892    })?;
1893    room_updated(room, &session.peer);
1894
1895    Ok(())
1896}
1897
1898/// Unshare a project from the room.
1899async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1900    let project_id = ProjectId::from_proto(message.project_id);
1901    unshare_project_internal(project_id, session.connection_id, &session).await
1902}
1903
1904async fn unshare_project_internal(
1905    project_id: ProjectId,
1906    connection_id: ConnectionId,
1907    session: &Session,
1908) -> Result<()> {
1909    let delete = {
1910        let room_guard = session
1911            .db()
1912            .await
1913            .unshare_project(project_id, connection_id)
1914            .await?;
1915
1916        let (delete, room, guest_connection_ids) = &*room_guard;
1917
1918        let message = proto::UnshareProject {
1919            project_id: project_id.to_proto(),
1920        };
1921
1922        broadcast(
1923            Some(connection_id),
1924            guest_connection_ids.iter().copied(),
1925            |conn_id| session.peer.send(conn_id, message.clone()),
1926        );
1927        if let Some(room) = room {
1928            room_updated(room, &session.peer);
1929        }
1930
1931        *delete
1932    };
1933
1934    if delete {
1935        let db = session.db().await;
1936        db.delete_project(project_id).await?;
1937    }
1938
1939    Ok(())
1940}
1941
1942/// Join someone elses shared project.
1943async fn join_project(
1944    request: proto::JoinProject,
1945    response: Response<proto::JoinProject>,
1946    session: MessageContext,
1947) -> Result<()> {
1948    let project_id = ProjectId::from_proto(request.project_id);
1949
1950    tracing::info!(%project_id, "join project");
1951
1952    let db = session.db().await;
1953    let (project, replica_id) = &mut *db
1954        .join_project(
1955            project_id,
1956            session.connection_id,
1957            session.user_id(),
1958            request.committer_name.clone(),
1959            request.committer_email.clone(),
1960        )
1961        .await?;
1962    drop(db);
1963    tracing::info!(%project_id, "join remote project");
1964    let collaborators = project
1965        .collaborators
1966        .iter()
1967        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1968        .map(|collaborator| collaborator.to_proto())
1969        .collect::<Vec<_>>();
1970    let project_id = project.id;
1971    let guest_user_id = session.user_id();
1972
1973    let worktrees = project
1974        .worktrees
1975        .iter()
1976        .map(|(id, worktree)| proto::WorktreeMetadata {
1977            id: *id,
1978            root_name: worktree.root_name.clone(),
1979            visible: worktree.visible,
1980            abs_path: worktree.abs_path.clone(),
1981        })
1982        .collect::<Vec<_>>();
1983
1984    let add_project_collaborator = proto::AddProjectCollaborator {
1985        project_id: project_id.to_proto(),
1986        collaborator: Some(proto::Collaborator {
1987            peer_id: Some(session.connection_id.into()),
1988            replica_id: replica_id.0 as u32,
1989            user_id: guest_user_id.to_proto(),
1990            is_host: false,
1991            committer_name: request.committer_name.clone(),
1992            committer_email: request.committer_email.clone(),
1993        }),
1994    };
1995
1996    for collaborator in &collaborators {
1997        session
1998            .peer
1999            .send(
2000                collaborator.peer_id.unwrap().into(),
2001                add_project_collaborator.clone(),
2002            )
2003            .trace_err();
2004    }
2005
2006    // First, we send the metadata associated with each worktree.
2007    let (language_servers, language_server_capabilities) = project
2008        .language_servers
2009        .clone()
2010        .into_iter()
2011        .map(|server| (server.server, server.capabilities))
2012        .unzip();
2013    response.send(proto::JoinProjectResponse {
2014        project_id: project.id.0 as u64,
2015        worktrees,
2016        replica_id: replica_id.0 as u32,
2017        collaborators,
2018        language_servers,
2019        language_server_capabilities,
2020        role: project.role.into(),
2021        windows_paths: project.path_style == PathStyle::Windows,
2022    })?;
2023
2024    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2025        // Stream this worktree's entries.
2026        let message = proto::UpdateWorktree {
2027            project_id: project_id.to_proto(),
2028            worktree_id,
2029            abs_path: worktree.abs_path.clone(),
2030            root_name: worktree.root_name,
2031            updated_entries: worktree.entries,
2032            removed_entries: Default::default(),
2033            scan_id: worktree.scan_id,
2034            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2035            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2036            removed_repositories: Default::default(),
2037        };
2038        for update in proto::split_worktree_update(message) {
2039            session.peer.send(session.connection_id, update.clone())?;
2040        }
2041
2042        // Stream this worktree's diagnostics.
2043        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
2044        if let Some(summary) = worktree_diagnostics.next() {
2045            let message = proto::UpdateDiagnosticSummary {
2046                project_id: project.id.to_proto(),
2047                worktree_id: worktree.id,
2048                summary: Some(summary),
2049                more_summaries: worktree_diagnostics.collect(),
2050            };
2051            session.peer.send(session.connection_id, message)?;
2052        }
2053
2054        for settings_file in worktree.settings_files {
2055            session.peer.send(
2056                session.connection_id,
2057                proto::UpdateWorktreeSettings {
2058                    project_id: project_id.to_proto(),
2059                    worktree_id: worktree.id,
2060                    path: settings_file.path,
2061                    content: Some(settings_file.content),
2062                    kind: Some(settings_file.kind.to_proto() as i32),
2063                },
2064            )?;
2065        }
2066    }
2067
2068    for repository in mem::take(&mut project.repositories) {
2069        for update in split_repository_update(repository) {
2070            session.peer.send(session.connection_id, update)?;
2071        }
2072    }
2073
2074    for language_server in &project.language_servers {
2075        session.peer.send(
2076            session.connection_id,
2077            proto::UpdateLanguageServer {
2078                project_id: project_id.to_proto(),
2079                server_name: Some(language_server.server.name.clone()),
2080                language_server_id: language_server.server.id,
2081                variant: Some(
2082                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2083                        proto::LspDiskBasedDiagnosticsUpdated {},
2084                    ),
2085                ),
2086            },
2087        )?;
2088    }
2089
2090    Ok(())
2091}
2092
2093/// Leave someone elses shared project.
2094async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2095    let sender_id = session.connection_id;
2096    let project_id = ProjectId::from_proto(request.project_id);
2097    let db = session.db().await;
2098
2099    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2100    tracing::info!(
2101        %project_id,
2102        "leave project"
2103    );
2104
2105    project_left(project, &session);
2106    if let Some(room) = room {
2107        room_updated(room, &session.peer);
2108    }
2109
2110    Ok(())
2111}
2112
2113/// Updates other participants with changes to the project
2114async fn update_project(
2115    request: proto::UpdateProject,
2116    response: Response<proto::UpdateProject>,
2117    session: MessageContext,
2118) -> Result<()> {
2119    let project_id = ProjectId::from_proto(request.project_id);
2120    let (room, guest_connection_ids) = &*session
2121        .db()
2122        .await
2123        .update_project(project_id, session.connection_id, &request.worktrees)
2124        .await?;
2125    broadcast(
2126        Some(session.connection_id),
2127        guest_connection_ids.iter().copied(),
2128        |connection_id| {
2129            session
2130                .peer
2131                .forward_send(session.connection_id, connection_id, request.clone())
2132        },
2133    );
2134    if let Some(room) = room {
2135        room_updated(room, &session.peer);
2136    }
2137    response.send(proto::Ack {})?;
2138
2139    Ok(())
2140}
2141
2142/// Updates other participants with changes to the worktree
2143async fn update_worktree(
2144    request: proto::UpdateWorktree,
2145    response: Response<proto::UpdateWorktree>,
2146    session: MessageContext,
2147) -> Result<()> {
2148    let guest_connection_ids = session
2149        .db()
2150        .await
2151        .update_worktree(&request, session.connection_id)
2152        .await?;
2153
2154    broadcast(
2155        Some(session.connection_id),
2156        guest_connection_ids.iter().copied(),
2157        |connection_id| {
2158            session
2159                .peer
2160                .forward_send(session.connection_id, connection_id, request.clone())
2161        },
2162    );
2163    response.send(proto::Ack {})?;
2164    Ok(())
2165}
2166
2167async fn update_repository(
2168    request: proto::UpdateRepository,
2169    response: Response<proto::UpdateRepository>,
2170    session: MessageContext,
2171) -> Result<()> {
2172    let guest_connection_ids = session
2173        .db()
2174        .await
2175        .update_repository(&request, session.connection_id)
2176        .await?;
2177
2178    broadcast(
2179        Some(session.connection_id),
2180        guest_connection_ids.iter().copied(),
2181        |connection_id| {
2182            session
2183                .peer
2184                .forward_send(session.connection_id, connection_id, request.clone())
2185        },
2186    );
2187    response.send(proto::Ack {})?;
2188    Ok(())
2189}
2190
2191async fn remove_repository(
2192    request: proto::RemoveRepository,
2193    response: Response<proto::RemoveRepository>,
2194    session: MessageContext,
2195) -> Result<()> {
2196    let guest_connection_ids = session
2197        .db()
2198        .await
2199        .remove_repository(&request, session.connection_id)
2200        .await?;
2201
2202    broadcast(
2203        Some(session.connection_id),
2204        guest_connection_ids.iter().copied(),
2205        |connection_id| {
2206            session
2207                .peer
2208                .forward_send(session.connection_id, connection_id, request.clone())
2209        },
2210    );
2211    response.send(proto::Ack {})?;
2212    Ok(())
2213}
2214
2215/// Updates other participants with changes to the diagnostics
2216async fn update_diagnostic_summary(
2217    message: proto::UpdateDiagnosticSummary,
2218    session: MessageContext,
2219) -> Result<()> {
2220    let guest_connection_ids = session
2221        .db()
2222        .await
2223        .update_diagnostic_summary(&message, session.connection_id)
2224        .await?;
2225
2226    broadcast(
2227        Some(session.connection_id),
2228        guest_connection_ids.iter().copied(),
2229        |connection_id| {
2230            session
2231                .peer
2232                .forward_send(session.connection_id, connection_id, message.clone())
2233        },
2234    );
2235
2236    Ok(())
2237}
2238
2239/// Updates other participants with changes to the worktree settings
2240async fn update_worktree_settings(
2241    message: proto::UpdateWorktreeSettings,
2242    session: MessageContext,
2243) -> Result<()> {
2244    let guest_connection_ids = session
2245        .db()
2246        .await
2247        .update_worktree_settings(&message, session.connection_id)
2248        .await?;
2249
2250    broadcast(
2251        Some(session.connection_id),
2252        guest_connection_ids.iter().copied(),
2253        |connection_id| {
2254            session
2255                .peer
2256                .forward_send(session.connection_id, connection_id, message.clone())
2257        },
2258    );
2259
2260    Ok(())
2261}
2262
2263/// Notify other participants that a language server has started.
2264async fn start_language_server(
2265    request: proto::StartLanguageServer,
2266    session: MessageContext,
2267) -> Result<()> {
2268    let guest_connection_ids = session
2269        .db()
2270        .await
2271        .start_language_server(&request, session.connection_id)
2272        .await?;
2273
2274    broadcast(
2275        Some(session.connection_id),
2276        guest_connection_ids.iter().copied(),
2277        |connection_id| {
2278            session
2279                .peer
2280                .forward_send(session.connection_id, connection_id, request.clone())
2281        },
2282    );
2283    Ok(())
2284}
2285
2286/// Notify other participants that a language server has changed.
2287async fn update_language_server(
2288    request: proto::UpdateLanguageServer,
2289    session: MessageContext,
2290) -> Result<()> {
2291    let project_id = ProjectId::from_proto(request.project_id);
2292    let db = session.db().await;
2293
2294    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2295        && let Some(capabilities) = update.capabilities.clone()
2296    {
2297        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2298            .await?;
2299    }
2300
2301    let project_connection_ids = db
2302        .project_connection_ids(project_id, session.connection_id, true)
2303        .await?;
2304    broadcast(
2305        Some(session.connection_id),
2306        project_connection_ids.iter().copied(),
2307        |connection_id| {
2308            session
2309                .peer
2310                .forward_send(session.connection_id, connection_id, request.clone())
2311        },
2312    );
2313    Ok(())
2314}
2315
2316/// forward a project request to the host. These requests should be read only
2317/// as guests are allowed to send them.
2318async fn forward_read_only_project_request<T>(
2319    request: T,
2320    response: Response<T>,
2321    session: MessageContext,
2322) -> Result<()>
2323where
2324    T: EntityMessage + RequestMessage,
2325{
2326    let project_id = ProjectId::from_proto(request.remote_entity_id());
2327    let host_connection_id = session
2328        .db()
2329        .await
2330        .host_for_read_only_project_request(project_id, session.connection_id)
2331        .await?;
2332    let payload = session.forward_request(host_connection_id, request).await?;
2333    response.send(payload)?;
2334    Ok(())
2335}
2336
2337/// forward a project request to the host. These requests are disallowed
2338/// for guests.
2339async fn forward_mutating_project_request<T>(
2340    request: T,
2341    response: Response<T>,
2342    session: MessageContext,
2343) -> Result<()>
2344where
2345    T: EntityMessage + RequestMessage,
2346{
2347    let project_id = ProjectId::from_proto(request.remote_entity_id());
2348
2349    let host_connection_id = session
2350        .db()
2351        .await
2352        .host_for_mutating_project_request(project_id, session.connection_id)
2353        .await?;
2354    let payload = session.forward_request(host_connection_id, request).await?;
2355    response.send(payload)?;
2356    Ok(())
2357}
2358
2359async fn lsp_query(
2360    request: proto::LspQuery,
2361    response: Response<proto::LspQuery>,
2362    session: MessageContext,
2363) -> Result<()> {
2364    let (name, should_write) = request.query_name_and_write_permissions();
2365    tracing::Span::current().record("lsp_query_request", name);
2366    tracing::info!("lsp_query message received");
2367    if should_write {
2368        forward_mutating_project_request(request, response, session).await
2369    } else {
2370        forward_read_only_project_request(request, response, session).await
2371    }
2372}
2373
2374/// Notify other participants that a new buffer has been created
2375async fn create_buffer_for_peer(
2376    request: proto::CreateBufferForPeer,
2377    session: MessageContext,
2378) -> Result<()> {
2379    session
2380        .db()
2381        .await
2382        .check_user_is_project_host(
2383            ProjectId::from_proto(request.project_id),
2384            session.connection_id,
2385        )
2386        .await?;
2387    let peer_id = request.peer_id.context("invalid peer id")?;
2388    session
2389        .peer
2390        .forward_send(session.connection_id, peer_id.into(), request)?;
2391    Ok(())
2392}
2393
2394/// Notify other participants that a new image has been created
2395async fn create_image_for_peer(
2396    request: proto::CreateImageForPeer,
2397    session: MessageContext,
2398) -> Result<()> {
2399    session
2400        .db()
2401        .await
2402        .check_user_is_project_host(
2403            ProjectId::from_proto(request.project_id),
2404            session.connection_id,
2405        )
2406        .await?;
2407    let peer_id = request.peer_id.context("invalid peer id")?;
2408    session
2409        .peer
2410        .forward_send(session.connection_id, peer_id.into(), request)?;
2411    Ok(())
2412}
2413
2414/// Notify other participants that a buffer has been updated. This is
2415/// allowed for guests as long as the update is limited to selections.
2416async fn update_buffer(
2417    request: proto::UpdateBuffer,
2418    response: Response<proto::UpdateBuffer>,
2419    session: MessageContext,
2420) -> Result<()> {
2421    let project_id = ProjectId::from_proto(request.project_id);
2422    let mut capability = Capability::ReadOnly;
2423
2424    for op in request.operations.iter() {
2425        match op.variant {
2426            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2427            Some(_) => capability = Capability::ReadWrite,
2428        }
2429    }
2430
2431    let host = {
2432        let guard = session
2433            .db()
2434            .await
2435            .connections_for_buffer_update(project_id, session.connection_id, capability)
2436            .await?;
2437
2438        let (host, guests) = &*guard;
2439
2440        broadcast(
2441            Some(session.connection_id),
2442            guests.clone(),
2443            |connection_id| {
2444                session
2445                    .peer
2446                    .forward_send(session.connection_id, connection_id, request.clone())
2447            },
2448        );
2449
2450        *host
2451    };
2452
2453    if host != session.connection_id {
2454        session.forward_request(host, request.clone()).await?;
2455    }
2456
2457    response.send(proto::Ack {})?;
2458    Ok(())
2459}
2460
2461async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2462    let project_id = ProjectId::from_proto(message.project_id);
2463
2464    let operation = message.operation.as_ref().context("invalid operation")?;
2465    let capability = match operation.variant.as_ref() {
2466        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2467            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2468                match buffer_op.variant {
2469                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2470                        Capability::ReadOnly
2471                    }
2472                    _ => Capability::ReadWrite,
2473                }
2474            } else {
2475                Capability::ReadWrite
2476            }
2477        }
2478        Some(_) => Capability::ReadWrite,
2479        None => Capability::ReadOnly,
2480    };
2481
2482    let guard = session
2483        .db()
2484        .await
2485        .connections_for_buffer_update(project_id, session.connection_id, capability)
2486        .await?;
2487
2488    let (host, guests) = &*guard;
2489
2490    broadcast(
2491        Some(session.connection_id),
2492        guests.iter().chain([host]).copied(),
2493        |connection_id| {
2494            session
2495                .peer
2496                .forward_send(session.connection_id, connection_id, message.clone())
2497        },
2498    );
2499
2500    Ok(())
2501}
2502
2503/// Notify other participants that a project has been updated.
2504async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2505    request: T,
2506    session: MessageContext,
2507) -> Result<()> {
2508    let project_id = ProjectId::from_proto(request.remote_entity_id());
2509    let project_connection_ids = session
2510        .db()
2511        .await
2512        .project_connection_ids(project_id, session.connection_id, false)
2513        .await?;
2514
2515    broadcast(
2516        Some(session.connection_id),
2517        project_connection_ids.iter().copied(),
2518        |connection_id| {
2519            session
2520                .peer
2521                .forward_send(session.connection_id, connection_id, request.clone())
2522        },
2523    );
2524    Ok(())
2525}
2526
2527/// Start following another user in a call.
2528async fn follow(
2529    request: proto::Follow,
2530    response: Response<proto::Follow>,
2531    session: MessageContext,
2532) -> Result<()> {
2533    let room_id = RoomId::from_proto(request.room_id);
2534    let project_id = request.project_id.map(ProjectId::from_proto);
2535    let leader_id = request.leader_id.context("invalid leader id")?.into();
2536    let follower_id = session.connection_id;
2537
2538    session
2539        .db()
2540        .await
2541        .check_room_participants(room_id, leader_id, session.connection_id)
2542        .await?;
2543
2544    let response_payload = session.forward_request(leader_id, request).await?;
2545    response.send(response_payload)?;
2546
2547    if let Some(project_id) = project_id {
2548        let room = session
2549            .db()
2550            .await
2551            .follow(room_id, project_id, leader_id, follower_id)
2552            .await?;
2553        room_updated(&room, &session.peer);
2554    }
2555
2556    Ok(())
2557}
2558
2559/// Stop following another user in a call.
2560async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2561    let room_id = RoomId::from_proto(request.room_id);
2562    let project_id = request.project_id.map(ProjectId::from_proto);
2563    let leader_id = request.leader_id.context("invalid leader id")?.into();
2564    let follower_id = session.connection_id;
2565
2566    session
2567        .db()
2568        .await
2569        .check_room_participants(room_id, leader_id, session.connection_id)
2570        .await?;
2571
2572    session
2573        .peer
2574        .forward_send(session.connection_id, leader_id, request)?;
2575
2576    if let Some(project_id) = project_id {
2577        let room = session
2578            .db()
2579            .await
2580            .unfollow(room_id, project_id, leader_id, follower_id)
2581            .await?;
2582        room_updated(&room, &session.peer);
2583    }
2584
2585    Ok(())
2586}
2587
2588/// Notify everyone following you of your current location.
2589async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2590    let room_id = RoomId::from_proto(request.room_id);
2591    let database = session.db.lock().await;
2592
2593    let connection_ids = if let Some(project_id) = request.project_id {
2594        let project_id = ProjectId::from_proto(project_id);
2595        database
2596            .project_connection_ids(project_id, session.connection_id, true)
2597            .await?
2598    } else {
2599        database
2600            .room_connection_ids(room_id, session.connection_id)
2601            .await?
2602    };
2603
2604    // For now, don't send view update messages back to that view's current leader.
2605    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2606        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2607        _ => None,
2608    });
2609
2610    for connection_id in connection_ids.iter().cloned() {
2611        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2612            session
2613                .peer
2614                .forward_send(session.connection_id, connection_id, request.clone())?;
2615        }
2616    }
2617    Ok(())
2618}
2619
2620/// Get public data about users.
2621async fn get_users(
2622    request: proto::GetUsers,
2623    response: Response<proto::GetUsers>,
2624    session: MessageContext,
2625) -> Result<()> {
2626    let user_ids = request
2627        .user_ids
2628        .into_iter()
2629        .map(UserId::from_proto)
2630        .collect();
2631    let users = session
2632        .db()
2633        .await
2634        .get_users_by_ids(user_ids)
2635        .await?
2636        .into_iter()
2637        .map(|user| proto::User {
2638            id: user.id.to_proto(),
2639            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2640            github_login: user.github_login,
2641            name: user.name,
2642        })
2643        .collect();
2644    response.send(proto::UsersResponse { users })?;
2645    Ok(())
2646}
2647
2648/// Search for users (to invite) buy Github login
2649async fn fuzzy_search_users(
2650    request: proto::FuzzySearchUsers,
2651    response: Response<proto::FuzzySearchUsers>,
2652    session: MessageContext,
2653) -> Result<()> {
2654    let query = request.query;
2655    let users = match query.len() {
2656        0 => vec![],
2657        1 | 2 => session
2658            .db()
2659            .await
2660            .get_user_by_github_login(&query)
2661            .await?
2662            .into_iter()
2663            .collect(),
2664        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2665    };
2666    let users = users
2667        .into_iter()
2668        .filter(|user| user.id != session.user_id())
2669        .map(|user| proto::User {
2670            id: user.id.to_proto(),
2671            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2672            github_login: user.github_login,
2673            name: user.name,
2674        })
2675        .collect();
2676    response.send(proto::UsersResponse { users })?;
2677    Ok(())
2678}
2679
2680/// Send a contact request to another user.
2681async fn request_contact(
2682    request: proto::RequestContact,
2683    response: Response<proto::RequestContact>,
2684    session: MessageContext,
2685) -> Result<()> {
2686    let requester_id = session.user_id();
2687    let responder_id = UserId::from_proto(request.responder_id);
2688    if requester_id == responder_id {
2689        return Err(anyhow!("cannot add yourself as a contact"))?;
2690    }
2691
2692    let notifications = session
2693        .db()
2694        .await
2695        .send_contact_request(requester_id, responder_id)
2696        .await?;
2697
2698    // Update outgoing contact requests of requester
2699    let mut update = proto::UpdateContacts::default();
2700    update.outgoing_requests.push(responder_id.to_proto());
2701    for connection_id in session
2702        .connection_pool()
2703        .await
2704        .user_connection_ids(requester_id)
2705    {
2706        session.peer.send(connection_id, update.clone())?;
2707    }
2708
2709    // Update incoming contact requests of responder
2710    let mut update = proto::UpdateContacts::default();
2711    update
2712        .incoming_requests
2713        .push(proto::IncomingContactRequest {
2714            requester_id: requester_id.to_proto(),
2715        });
2716    let connection_pool = session.connection_pool().await;
2717    for connection_id in connection_pool.user_connection_ids(responder_id) {
2718        session.peer.send(connection_id, update.clone())?;
2719    }
2720
2721    send_notifications(&connection_pool, &session.peer, notifications);
2722
2723    response.send(proto::Ack {})?;
2724    Ok(())
2725}
2726
2727/// Accept or decline a contact request
2728async fn respond_to_contact_request(
2729    request: proto::RespondToContactRequest,
2730    response: Response<proto::RespondToContactRequest>,
2731    session: MessageContext,
2732) -> Result<()> {
2733    let responder_id = session.user_id();
2734    let requester_id = UserId::from_proto(request.requester_id);
2735    let db = session.db().await;
2736    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2737        db.dismiss_contact_notification(responder_id, requester_id)
2738            .await?;
2739    } else {
2740        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2741
2742        let notifications = db
2743            .respond_to_contact_request(responder_id, requester_id, accept)
2744            .await?;
2745        let requester_busy = db.is_user_busy(requester_id).await?;
2746        let responder_busy = db.is_user_busy(responder_id).await?;
2747
2748        let pool = session.connection_pool().await;
2749        // Update responder with new contact
2750        let mut update = proto::UpdateContacts::default();
2751        if accept {
2752            update
2753                .contacts
2754                .push(contact_for_user(requester_id, requester_busy, &pool));
2755        }
2756        update
2757            .remove_incoming_requests
2758            .push(requester_id.to_proto());
2759        for connection_id in pool.user_connection_ids(responder_id) {
2760            session.peer.send(connection_id, update.clone())?;
2761        }
2762
2763        // Update requester with new contact
2764        let mut update = proto::UpdateContacts::default();
2765        if accept {
2766            update
2767                .contacts
2768                .push(contact_for_user(responder_id, responder_busy, &pool));
2769        }
2770        update
2771            .remove_outgoing_requests
2772            .push(responder_id.to_proto());
2773
2774        for connection_id in pool.user_connection_ids(requester_id) {
2775            session.peer.send(connection_id, update.clone())?;
2776        }
2777
2778        send_notifications(&pool, &session.peer, notifications);
2779    }
2780
2781    response.send(proto::Ack {})?;
2782    Ok(())
2783}
2784
2785/// Remove a contact.
2786async fn remove_contact(
2787    request: proto::RemoveContact,
2788    response: Response<proto::RemoveContact>,
2789    session: MessageContext,
2790) -> Result<()> {
2791    let requester_id = session.user_id();
2792    let responder_id = UserId::from_proto(request.user_id);
2793    let db = session.db().await;
2794    let (contact_accepted, deleted_notification_id) =
2795        db.remove_contact(requester_id, responder_id).await?;
2796
2797    let pool = session.connection_pool().await;
2798    // Update outgoing contact requests of requester
2799    let mut update = proto::UpdateContacts::default();
2800    if contact_accepted {
2801        update.remove_contacts.push(responder_id.to_proto());
2802    } else {
2803        update
2804            .remove_outgoing_requests
2805            .push(responder_id.to_proto());
2806    }
2807    for connection_id in pool.user_connection_ids(requester_id) {
2808        session.peer.send(connection_id, update.clone())?;
2809    }
2810
2811    // Update incoming contact requests of responder
2812    let mut update = proto::UpdateContacts::default();
2813    if contact_accepted {
2814        update.remove_contacts.push(requester_id.to_proto());
2815    } else {
2816        update
2817            .remove_incoming_requests
2818            .push(requester_id.to_proto());
2819    }
2820    for connection_id in pool.user_connection_ids(responder_id) {
2821        session.peer.send(connection_id, update.clone())?;
2822        if let Some(notification_id) = deleted_notification_id {
2823            session.peer.send(
2824                connection_id,
2825                proto::DeleteNotification {
2826                    notification_id: notification_id.to_proto(),
2827                },
2828            )?;
2829        }
2830    }
2831
2832    response.send(proto::Ack {})?;
2833    Ok(())
2834}
2835
2836fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2837    version.0.minor() < 139
2838}
2839
2840async fn subscribe_to_channels(
2841    _: proto::SubscribeToChannels,
2842    session: MessageContext,
2843) -> Result<()> {
2844    subscribe_user_to_channels(session.user_id(), &session).await?;
2845    Ok(())
2846}
2847
2848async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2849    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2850    let mut pool = session.connection_pool().await;
2851    for membership in &channels_for_user.channel_memberships {
2852        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2853    }
2854    session.peer.send(
2855        session.connection_id,
2856        build_update_user_channels(&channels_for_user),
2857    )?;
2858    session.peer.send(
2859        session.connection_id,
2860        build_channels_update(channels_for_user),
2861    )?;
2862    Ok(())
2863}
2864
2865/// Creates a new channel.
2866async fn create_channel(
2867    request: proto::CreateChannel,
2868    response: Response<proto::CreateChannel>,
2869    session: MessageContext,
2870) -> Result<()> {
2871    let db = session.db().await;
2872
2873    let parent_id = request.parent_id.map(ChannelId::from_proto);
2874    let (channel, membership) = db
2875        .create_channel(&request.name, parent_id, session.user_id())
2876        .await?;
2877
2878    let root_id = channel.root_id();
2879    let channel = Channel::from_model(channel);
2880
2881    response.send(proto::CreateChannelResponse {
2882        channel: Some(channel.to_proto()),
2883        parent_id: request.parent_id,
2884    })?;
2885
2886    let mut connection_pool = session.connection_pool().await;
2887    if let Some(membership) = membership {
2888        connection_pool.subscribe_to_channel(
2889            membership.user_id,
2890            membership.channel_id,
2891            membership.role,
2892        );
2893        let update = proto::UpdateUserChannels {
2894            channel_memberships: vec![proto::ChannelMembership {
2895                channel_id: membership.channel_id.to_proto(),
2896                role: membership.role.into(),
2897            }],
2898            ..Default::default()
2899        };
2900        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2901            session.peer.send(connection_id, update.clone())?;
2902        }
2903    }
2904
2905    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2906        if !role.can_see_channel(channel.visibility) {
2907            continue;
2908        }
2909
2910        let update = proto::UpdateChannels {
2911            channels: vec![channel.to_proto()],
2912            ..Default::default()
2913        };
2914        session.peer.send(connection_id, update.clone())?;
2915    }
2916
2917    Ok(())
2918}
2919
2920/// Delete a channel
2921async fn delete_channel(
2922    request: proto::DeleteChannel,
2923    response: Response<proto::DeleteChannel>,
2924    session: MessageContext,
2925) -> Result<()> {
2926    let db = session.db().await;
2927
2928    let channel_id = request.channel_id;
2929    let (root_channel, removed_channels) = db
2930        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2931        .await?;
2932    response.send(proto::Ack {})?;
2933
2934    // Notify members of removed channels
2935    let mut update = proto::UpdateChannels::default();
2936    update
2937        .delete_channels
2938        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2939
2940    let connection_pool = session.connection_pool().await;
2941    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2942        session.peer.send(connection_id, update.clone())?;
2943    }
2944
2945    Ok(())
2946}
2947
2948/// Invite someone to join a channel.
2949async fn invite_channel_member(
2950    request: proto::InviteChannelMember,
2951    response: Response<proto::InviteChannelMember>,
2952    session: MessageContext,
2953) -> Result<()> {
2954    let db = session.db().await;
2955    let channel_id = ChannelId::from_proto(request.channel_id);
2956    let invitee_id = UserId::from_proto(request.user_id);
2957    let InviteMemberResult {
2958        channel,
2959        notifications,
2960    } = db
2961        .invite_channel_member(
2962            channel_id,
2963            invitee_id,
2964            session.user_id(),
2965            request.role().into(),
2966        )
2967        .await?;
2968
2969    let update = proto::UpdateChannels {
2970        channel_invitations: vec![channel.to_proto()],
2971        ..Default::default()
2972    };
2973
2974    let connection_pool = session.connection_pool().await;
2975    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2976        session.peer.send(connection_id, update.clone())?;
2977    }
2978
2979    send_notifications(&connection_pool, &session.peer, notifications);
2980
2981    response.send(proto::Ack {})?;
2982    Ok(())
2983}
2984
2985/// remove someone from a channel
2986async fn remove_channel_member(
2987    request: proto::RemoveChannelMember,
2988    response: Response<proto::RemoveChannelMember>,
2989    session: MessageContext,
2990) -> Result<()> {
2991    let db = session.db().await;
2992    let channel_id = ChannelId::from_proto(request.channel_id);
2993    let member_id = UserId::from_proto(request.user_id);
2994
2995    let RemoveChannelMemberResult {
2996        membership_update,
2997        notification_id,
2998    } = db
2999        .remove_channel_member(channel_id, member_id, session.user_id())
3000        .await?;
3001
3002    let mut connection_pool = session.connection_pool().await;
3003    notify_membership_updated(
3004        &mut connection_pool,
3005        membership_update,
3006        member_id,
3007        &session.peer,
3008    );
3009    for connection_id in connection_pool.user_connection_ids(member_id) {
3010        if let Some(notification_id) = notification_id {
3011            session
3012                .peer
3013                .send(
3014                    connection_id,
3015                    proto::DeleteNotification {
3016                        notification_id: notification_id.to_proto(),
3017                    },
3018                )
3019                .trace_err();
3020        }
3021    }
3022
3023    response.send(proto::Ack {})?;
3024    Ok(())
3025}
3026
3027/// Toggle the channel between public and private.
3028/// Care is taken to maintain the invariant that public channels only descend from public channels,
3029/// (though members-only channels can appear at any point in the hierarchy).
3030async fn set_channel_visibility(
3031    request: proto::SetChannelVisibility,
3032    response: Response<proto::SetChannelVisibility>,
3033    session: MessageContext,
3034) -> Result<()> {
3035    let db = session.db().await;
3036    let channel_id = ChannelId::from_proto(request.channel_id);
3037    let visibility = request.visibility().into();
3038
3039    let channel_model = db
3040        .set_channel_visibility(channel_id, visibility, session.user_id())
3041        .await?;
3042    let root_id = channel_model.root_id();
3043    let channel = Channel::from_model(channel_model);
3044
3045    let mut connection_pool = session.connection_pool().await;
3046    for (user_id, role) in connection_pool
3047        .channel_user_ids(root_id)
3048        .collect::<Vec<_>>()
3049        .into_iter()
3050    {
3051        let update = if role.can_see_channel(channel.visibility) {
3052            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3053            proto::UpdateChannels {
3054                channels: vec![channel.to_proto()],
3055                ..Default::default()
3056            }
3057        } else {
3058            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3059            proto::UpdateChannels {
3060                delete_channels: vec![channel.id.to_proto()],
3061                ..Default::default()
3062            }
3063        };
3064
3065        for connection_id in connection_pool.user_connection_ids(user_id) {
3066            session.peer.send(connection_id, update.clone())?;
3067        }
3068    }
3069
3070    response.send(proto::Ack {})?;
3071    Ok(())
3072}
3073
3074/// Alter the role for a user in the channel.
3075async fn set_channel_member_role(
3076    request: proto::SetChannelMemberRole,
3077    response: Response<proto::SetChannelMemberRole>,
3078    session: MessageContext,
3079) -> Result<()> {
3080    let db = session.db().await;
3081    let channel_id = ChannelId::from_proto(request.channel_id);
3082    let member_id = UserId::from_proto(request.user_id);
3083    let result = db
3084        .set_channel_member_role(
3085            channel_id,
3086            session.user_id(),
3087            member_id,
3088            request.role().into(),
3089        )
3090        .await?;
3091
3092    match result {
3093        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3094            let mut connection_pool = session.connection_pool().await;
3095            notify_membership_updated(
3096                &mut connection_pool,
3097                membership_update,
3098                member_id,
3099                &session.peer,
3100            )
3101        }
3102        db::SetMemberRoleResult::InviteUpdated(channel) => {
3103            let update = proto::UpdateChannels {
3104                channel_invitations: vec![channel.to_proto()],
3105                ..Default::default()
3106            };
3107
3108            for connection_id in session
3109                .connection_pool()
3110                .await
3111                .user_connection_ids(member_id)
3112            {
3113                session.peer.send(connection_id, update.clone())?;
3114            }
3115        }
3116    }
3117
3118    response.send(proto::Ack {})?;
3119    Ok(())
3120}
3121
3122/// Change the name of a channel
3123async fn rename_channel(
3124    request: proto::RenameChannel,
3125    response: Response<proto::RenameChannel>,
3126    session: MessageContext,
3127) -> Result<()> {
3128    let db = session.db().await;
3129    let channel_id = ChannelId::from_proto(request.channel_id);
3130    let channel_model = db
3131        .rename_channel(channel_id, session.user_id(), &request.name)
3132        .await?;
3133    let root_id = channel_model.root_id();
3134    let channel = Channel::from_model(channel_model);
3135
3136    response.send(proto::RenameChannelResponse {
3137        channel: Some(channel.to_proto()),
3138    })?;
3139
3140    let connection_pool = session.connection_pool().await;
3141    let update = proto::UpdateChannels {
3142        channels: vec![channel.to_proto()],
3143        ..Default::default()
3144    };
3145    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3146        if role.can_see_channel(channel.visibility) {
3147            session.peer.send(connection_id, update.clone())?;
3148        }
3149    }
3150
3151    Ok(())
3152}
3153
3154/// Move a channel to a new parent.
3155async fn move_channel(
3156    request: proto::MoveChannel,
3157    response: Response<proto::MoveChannel>,
3158    session: MessageContext,
3159) -> Result<()> {
3160    let channel_id = ChannelId::from_proto(request.channel_id);
3161    let to = ChannelId::from_proto(request.to);
3162
3163    let (root_id, channels) = session
3164        .db()
3165        .await
3166        .move_channel(channel_id, to, session.user_id())
3167        .await?;
3168
3169    let connection_pool = session.connection_pool().await;
3170    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3171        let channels = channels
3172            .iter()
3173            .filter_map(|channel| {
3174                if role.can_see_channel(channel.visibility) {
3175                    Some(channel.to_proto())
3176                } else {
3177                    None
3178                }
3179            })
3180            .collect::<Vec<_>>();
3181        if channels.is_empty() {
3182            continue;
3183        }
3184
3185        let update = proto::UpdateChannels {
3186            channels,
3187            ..Default::default()
3188        };
3189
3190        session.peer.send(connection_id, update.clone())?;
3191    }
3192
3193    response.send(Ack {})?;
3194    Ok(())
3195}
3196
3197async fn reorder_channel(
3198    request: proto::ReorderChannel,
3199    response: Response<proto::ReorderChannel>,
3200    session: MessageContext,
3201) -> Result<()> {
3202    let channel_id = ChannelId::from_proto(request.channel_id);
3203    let direction = request.direction();
3204
3205    let updated_channels = session
3206        .db()
3207        .await
3208        .reorder_channel(channel_id, direction, session.user_id())
3209        .await?;
3210
3211    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3212        let connection_pool = session.connection_pool().await;
3213        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3214            let channels = updated_channels
3215                .iter()
3216                .filter_map(|channel| {
3217                    if role.can_see_channel(channel.visibility) {
3218                        Some(channel.to_proto())
3219                    } else {
3220                        None
3221                    }
3222                })
3223                .collect::<Vec<_>>();
3224
3225            if channels.is_empty() {
3226                continue;
3227            }
3228
3229            let update = proto::UpdateChannels {
3230                channels,
3231                ..Default::default()
3232            };
3233
3234            session.peer.send(connection_id, update.clone())?;
3235        }
3236    }
3237
3238    response.send(Ack {})?;
3239    Ok(())
3240}
3241
3242/// Get the list of channel members
3243async fn get_channel_members(
3244    request: proto::GetChannelMembers,
3245    response: Response<proto::GetChannelMembers>,
3246    session: MessageContext,
3247) -> Result<()> {
3248    let db = session.db().await;
3249    let channel_id = ChannelId::from_proto(request.channel_id);
3250    let limit = if request.limit == 0 {
3251        u16::MAX as u64
3252    } else {
3253        request.limit
3254    };
3255    let (members, users) = db
3256        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3257        .await?;
3258    response.send(proto::GetChannelMembersResponse { members, users })?;
3259    Ok(())
3260}
3261
3262/// Accept or decline a channel invitation.
3263async fn respond_to_channel_invite(
3264    request: proto::RespondToChannelInvite,
3265    response: Response<proto::RespondToChannelInvite>,
3266    session: MessageContext,
3267) -> Result<()> {
3268    let db = session.db().await;
3269    let channel_id = ChannelId::from_proto(request.channel_id);
3270    let RespondToChannelInvite {
3271        membership_update,
3272        notifications,
3273    } = db
3274        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3275        .await?;
3276
3277    let mut connection_pool = session.connection_pool().await;
3278    if let Some(membership_update) = membership_update {
3279        notify_membership_updated(
3280            &mut connection_pool,
3281            membership_update,
3282            session.user_id(),
3283            &session.peer,
3284        );
3285    } else {
3286        let update = proto::UpdateChannels {
3287            remove_channel_invitations: vec![channel_id.to_proto()],
3288            ..Default::default()
3289        };
3290
3291        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3292            session.peer.send(connection_id, update.clone())?;
3293        }
3294    };
3295
3296    send_notifications(&connection_pool, &session.peer, notifications);
3297
3298    response.send(proto::Ack {})?;
3299
3300    Ok(())
3301}
3302
3303/// Join the channels' room
3304async fn join_channel(
3305    request: proto::JoinChannel,
3306    response: Response<proto::JoinChannel>,
3307    session: MessageContext,
3308) -> Result<()> {
3309    let channel_id = ChannelId::from_proto(request.channel_id);
3310    join_channel_internal(channel_id, Box::new(response), session).await
3311}
3312
/// Abstracts over the response handles that can answer a channel join.
///
/// `join_channel_internal` always replies with a `JoinRoomResponse`. This
/// trait lets it accept the response handle of either a `JoinChannel` or a
/// `JoinRoom` request.
trait JoinChannelInternalResponse {
    /// Send `result` back to the requesting client.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// The fully qualified `Response::<R>::send` path is used so that (presumably)
// the inherent `Response::send` is called rather than recursing into this
// trait's own `send`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3326
/// Shared implementation for joining a channel's room.
///
/// Steps, in order:
/// 1. If the user still has a stale room connection, leave that room first.
/// 2. Join the channel's room in the database. This yields the room, an
///    optional membership update, and the user's role.
/// 3. If a LiveKit client is configured, mint a token:
///    - guests get a guest token with `can_publish: false`;
///    - everyone else gets a room token with `can_publish: true`.
///
///    A token error is logged, and the reply then carries no LiveKit info.
/// 4. Reply to the client with the joined room.
/// 5. Apply any membership update and broadcast the room update.
/// 6. Broadcast the channel's participants and refresh the user's contacts.
///
/// The reply is sent before the broadcasts. An error in a later step (e.g. a
/// missing channel or a failing contacts update) is therefore returned after
/// the client has already received its response.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    // Everything up to the room broadcast happens in this inner scope, so the
    // db handle and connection-pool guard acquired here are released before
    // the connection pool is re-acquired for `channel_updated` below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the db handle first: `leave_room_for_session` acquires it
            // itself via `session.db()`.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when minting the token
        // failed (`trace_err` logs the error and yields `None`, which `?` propagates
        // out of the closure).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // A channel join is expected to always produce a channel; treat its absence
    // as an error (the client has already been answered at this point).
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3420
3421/// Start editing the channel notes
3422async fn join_channel_buffer(
3423    request: proto::JoinChannelBuffer,
3424    response: Response<proto::JoinChannelBuffer>,
3425    session: MessageContext,
3426) -> Result<()> {
3427    let db = session.db().await;
3428    let channel_id = ChannelId::from_proto(request.channel_id);
3429
3430    let open_response = db
3431        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3432        .await?;
3433
3434    let collaborators = open_response.collaborators.clone();
3435    response.send(open_response)?;
3436
3437    let update = UpdateChannelBufferCollaborators {
3438        channel_id: channel_id.to_proto(),
3439        collaborators: collaborators.clone(),
3440    };
3441    channel_buffer_updated(
3442        session.connection_id,
3443        collaborators
3444            .iter()
3445            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3446        &update,
3447        &session.peer,
3448    );
3449
3450    Ok(())
3451}
3452
3453/// Edit the channel notes
3454async fn update_channel_buffer(
3455    request: proto::UpdateChannelBuffer,
3456    session: MessageContext,
3457) -> Result<()> {
3458    let db = session.db().await;
3459    let channel_id = ChannelId::from_proto(request.channel_id);
3460
3461    let (collaborators, epoch, version) = db
3462        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3463        .await?;
3464
3465    channel_buffer_updated(
3466        session.connection_id,
3467        collaborators.clone(),
3468        &proto::UpdateChannelBuffer {
3469            channel_id: channel_id.to_proto(),
3470            operations: request.operations,
3471        },
3472        &session.peer,
3473    );
3474
3475    let pool = &*session.connection_pool().await;
3476
3477    let non_collaborators =
3478        pool.channel_connection_ids(channel_id)
3479            .filter_map(|(connection_id, _)| {
3480                if collaborators.contains(&connection_id) {
3481                    None
3482                } else {
3483                    Some(connection_id)
3484                }
3485            });
3486
3487    broadcast(None, non_collaborators, |peer_id| {
3488        session.peer.send(
3489            peer_id,
3490            proto::UpdateChannels {
3491                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3492                    channel_id: channel_id.to_proto(),
3493                    epoch: epoch as u64,
3494                    version: version.clone(),
3495                }],
3496                ..Default::default()
3497            },
3498        )
3499    });
3500
3501    Ok(())
3502}
3503
3504/// Rejoin the channel notes after a connection blip
3505async fn rejoin_channel_buffers(
3506    request: proto::RejoinChannelBuffers,
3507    response: Response<proto::RejoinChannelBuffers>,
3508    session: MessageContext,
3509) -> Result<()> {
3510    let db = session.db().await;
3511    let buffers = db
3512        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3513        .await?;
3514
3515    for rejoined_buffer in &buffers {
3516        let collaborators_to_notify = rejoined_buffer
3517            .buffer
3518            .collaborators
3519            .iter()
3520            .filter_map(|c| Some(c.peer_id?.into()));
3521        channel_buffer_updated(
3522            session.connection_id,
3523            collaborators_to_notify,
3524            &proto::UpdateChannelBufferCollaborators {
3525                channel_id: rejoined_buffer.buffer.channel_id,
3526                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3527            },
3528            &session.peer,
3529        );
3530    }
3531
3532    response.send(proto::RejoinChannelBuffersResponse {
3533        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3534    })?;
3535
3536    Ok(())
3537}
3538
3539/// Stop editing the channel notes
3540async fn leave_channel_buffer(
3541    request: proto::LeaveChannelBuffer,
3542    response: Response<proto::LeaveChannelBuffer>,
3543    session: MessageContext,
3544) -> Result<()> {
3545    let db = session.db().await;
3546    let channel_id = ChannelId::from_proto(request.channel_id);
3547
3548    let left_buffer = db
3549        .leave_channel_buffer(channel_id, session.connection_id)
3550        .await?;
3551
3552    response.send(Ack {})?;
3553
3554    channel_buffer_updated(
3555        session.connection_id,
3556        left_buffer.connections,
3557        &proto::UpdateChannelBufferCollaborators {
3558            channel_id: channel_id.to_proto(),
3559            collaborators: left_buffer.collaborators,
3560        },
3561        &session.peer,
3562    );
3563
3564    Ok(())
3565}
3566
3567fn channel_buffer_updated<T: EnvelopedMessage>(
3568    sender_id: ConnectionId,
3569    collaborators: impl IntoIterator<Item = ConnectionId>,
3570    message: &T,
3571    peer: &Peer,
3572) {
3573    broadcast(Some(sender_id), collaborators, |peer_id| {
3574        peer.send(peer_id, message.clone())
3575    });
3576}
3577
3578fn send_notifications(
3579    connection_pool: &ConnectionPool,
3580    peer: &Peer,
3581    notifications: db::NotificationBatch,
3582) {
3583    for (user_id, notification) in notifications {
3584        for connection_id in connection_pool.user_connection_ids(user_id) {
3585            if let Err(error) = peer.send(
3586                connection_id,
3587                proto::AddNotification {
3588                    notification: Some(notification.clone()),
3589                },
3590            ) {
3591                tracing::error!(
3592                    "failed to send notification to {:?} {}",
3593                    connection_id,
3594                    error
3595                );
3596            }
3597        }
3598    }
3599}
3600
3601/// Send a message to the channel
3602async fn send_channel_message(
3603    _request: proto::SendChannelMessage,
3604    _response: Response<proto::SendChannelMessage>,
3605    _session: MessageContext,
3606) -> Result<()> {
3607    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3608}
3609
3610/// Delete a channel message
3611async fn remove_channel_message(
3612    _request: proto::RemoveChannelMessage,
3613    _response: Response<proto::RemoveChannelMessage>,
3614    _session: MessageContext,
3615) -> Result<()> {
3616    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3617}
3618
3619async fn update_channel_message(
3620    _request: proto::UpdateChannelMessage,
3621    _response: Response<proto::UpdateChannelMessage>,
3622    _session: MessageContext,
3623) -> Result<()> {
3624    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3625}
3626
3627/// Mark a channel message as read
3628async fn acknowledge_channel_message(
3629    _request: proto::AckChannelMessage,
3630    _session: MessageContext,
3631) -> Result<()> {
3632    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3633}
3634
3635/// Mark a buffer version as synced
3636async fn acknowledge_buffer_version(
3637    request: proto::AckBufferOperation,
3638    session: MessageContext,
3639) -> Result<()> {
3640    let buffer_id = BufferId::from_proto(request.buffer_id);
3641    session
3642        .db()
3643        .await
3644        .observe_buffer_version(
3645            buffer_id,
3646            session.user_id(),
3647            request.epoch as i32,
3648            &request.version,
3649        )
3650        .await?;
3651    Ok(())
3652}
3653
3654/// Get a Supermaven API key for the user
3655async fn get_supermaven_api_key(
3656    _request: proto::GetSupermavenApiKey,
3657    response: Response<proto::GetSupermavenApiKey>,
3658    session: MessageContext,
3659) -> Result<()> {
3660    let user_id: String = session.user_id().to_string();
3661    if !session.is_staff() {
3662        return Err(anyhow!("supermaven not enabled for this account"))?;
3663    }
3664
3665    let email = session.email().context("user must have an email")?;
3666
3667    let supermaven_admin_api = session
3668        .supermaven_client
3669        .as_ref()
3670        .context("supermaven not configured")?;
3671
3672    let result = supermaven_admin_api
3673        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3674        .await?;
3675
3676    response.send(proto::GetSupermavenApiKeyResponse {
3677        api_key: result.api_key,
3678    })?;
3679
3680    Ok(())
3681}
3682
3683/// Start receiving chat updates for a channel
3684async fn join_channel_chat(
3685    _request: proto::JoinChannelChat,
3686    _response: Response<proto::JoinChannelChat>,
3687    _session: MessageContext,
3688) -> Result<()> {
3689    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3690}
3691
3692/// Stop receiving chat updates for a channel
3693async fn leave_channel_chat(
3694    _request: proto::LeaveChannelChat,
3695    _session: MessageContext,
3696) -> Result<()> {
3697    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3698}
3699
3700/// Retrieve the chat history for a channel
3701async fn get_channel_messages(
3702    _request: proto::GetChannelMessages,
3703    _response: Response<proto::GetChannelMessages>,
3704    _session: MessageContext,
3705) -> Result<()> {
3706    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3707}
3708
3709/// Retrieve specific chat messages
3710async fn get_channel_messages_by_id(
3711    _request: proto::GetChannelMessagesById,
3712    _response: Response<proto::GetChannelMessagesById>,
3713    _session: MessageContext,
3714) -> Result<()> {
3715    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3716}
3717
3718/// Retrieve the current users notifications
3719async fn get_notifications(
3720    request: proto::GetNotifications,
3721    response: Response<proto::GetNotifications>,
3722    session: MessageContext,
3723) -> Result<()> {
3724    let notifications = session
3725        .db()
3726        .await
3727        .get_notifications(
3728            session.user_id(),
3729            NOTIFICATION_COUNT_PER_PAGE,
3730            request.before_id.map(db::NotificationId::from_proto),
3731        )
3732        .await?;
3733    response.send(proto::GetNotificationsResponse {
3734        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3735        notifications,
3736    })?;
3737    Ok(())
3738}
3739
3740/// Mark notifications as read
3741async fn mark_notification_as_read(
3742    request: proto::MarkNotificationRead,
3743    response: Response<proto::MarkNotificationRead>,
3744    session: MessageContext,
3745) -> Result<()> {
3746    let database = &session.db().await;
3747    let notifications = database
3748        .mark_notification_as_read_by_id(
3749            session.user_id(),
3750            NotificationId::from_proto(request.notification_id),
3751        )
3752        .await?;
3753    send_notifications(
3754        &*session.connection_pool().await,
3755        &session.peer,
3756        notifications,
3757    );
3758    response.send(proto::Ack {})?;
3759    Ok(())
3760}
3761
3762fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3763    let message = match message {
3764        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3765        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3766        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3767        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3768        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3769            code: frame.code.into(),
3770            reason: frame.reason.as_str().to_owned().into(),
3771        })),
3772        // We should never receive a frame while reading the message, according
3773        // to the `tungstenite` maintainers:
3774        //
3775        // > It cannot occur when you read messages from the WebSocket, but it
3776        // > can be used when you want to send the raw frames (e.g. you want to
3777        // > send the frames to the WebSocket without composing the full message first).
3778        // >
3779        // > — https://github.com/snapview/tungstenite-rs/issues/268
3780        TungsteniteMessage::Frame(_) => {
3781            bail!("received an unexpected frame while reading the message")
3782        }
3783    };
3784
3785    Ok(message)
3786}
3787
3788fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3789    match message {
3790        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3791        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3792        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3793        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3794        AxumMessage::Close(frame) => {
3795            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3796                code: frame.code.into(),
3797                reason: frame.reason.as_ref().into(),
3798            }))
3799        }
3800    }
3801}
3802
3803fn notify_membership_updated(
3804    connection_pool: &mut ConnectionPool,
3805    result: MembershipUpdated,
3806    user_id: UserId,
3807    peer: &Peer,
3808) {
3809    for membership in &result.new_channels.channel_memberships {
3810        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3811    }
3812    for channel_id in &result.removed_channels {
3813        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3814    }
3815
3816    let user_channels_update = proto::UpdateUserChannels {
3817        channel_memberships: result
3818            .new_channels
3819            .channel_memberships
3820            .iter()
3821            .map(|cm| proto::ChannelMembership {
3822                channel_id: cm.channel_id.to_proto(),
3823                role: cm.role.into(),
3824            })
3825            .collect(),
3826        ..Default::default()
3827    };
3828
3829    let mut update = build_channels_update(result.new_channels);
3830    update.delete_channels = result
3831        .removed_channels
3832        .into_iter()
3833        .map(|id| id.to_proto())
3834        .collect();
3835    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3836
3837    for connection_id in connection_pool.user_connection_ids(user_id) {
3838        peer.send(connection_id, user_channels_update.clone())
3839            .trace_err();
3840        peer.send(connection_id, update.clone()).trace_err();
3841    }
3842}
3843
3844fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3845    proto::UpdateUserChannels {
3846        channel_memberships: channels
3847            .channel_memberships
3848            .iter()
3849            .map(|m| proto::ChannelMembership {
3850                channel_id: m.channel_id.to_proto(),
3851                role: m.role.into(),
3852            })
3853            .collect(),
3854        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3855    }
3856}
3857
3858fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3859    let mut update = proto::UpdateChannels::default();
3860
3861    for channel in channels.channels {
3862        update.channels.push(channel.to_proto());
3863    }
3864
3865    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3866
3867    for (channel_id, participants) in channels.channel_participants {
3868        update
3869            .channel_participants
3870            .push(proto::ChannelParticipants {
3871                channel_id: channel_id.to_proto(),
3872                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3873            });
3874    }
3875
3876    for channel in channels.invited_channels {
3877        update.channel_invitations.push(channel.to_proto());
3878    }
3879
3880    update
3881}
3882
3883fn build_initial_contacts_update(
3884    contacts: Vec<db::Contact>,
3885    pool: &ConnectionPool,
3886) -> proto::UpdateContacts {
3887    let mut update = proto::UpdateContacts::default();
3888
3889    for contact in contacts {
3890        match contact {
3891            db::Contact::Accepted { user_id, busy } => {
3892                update.contacts.push(contact_for_user(user_id, busy, pool));
3893            }
3894            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3895            db::Contact::Incoming { user_id } => {
3896                update
3897                    .incoming_requests
3898                    .push(proto::IncomingContactRequest {
3899                        requester_id: user_id.to_proto(),
3900                    })
3901            }
3902        }
3903    }
3904
3905    update
3906}
3907
3908fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3909    proto::Contact {
3910        user_id: user_id.to_proto(),
3911        online: pool.is_user_online(user_id),
3912        busy,
3913    }
3914}
3915
3916fn room_updated(room: &proto::Room, peer: &Peer) {
3917    broadcast(
3918        None,
3919        room.participants
3920            .iter()
3921            .filter_map(|participant| Some(participant.peer_id?.into())),
3922        |peer_id| {
3923            peer.send(
3924                peer_id,
3925                proto::RoomUpdated {
3926                    room: Some(room.clone()),
3927                },
3928            )
3929        },
3930    );
3931}
3932
3933fn channel_updated(
3934    channel: &db::channel::Model,
3935    room: &proto::Room,
3936    peer: &Peer,
3937    pool: &ConnectionPool,
3938) {
3939    let participants = room
3940        .participants
3941        .iter()
3942        .map(|p| p.user_id)
3943        .collect::<Vec<_>>();
3944
3945    broadcast(
3946        None,
3947        pool.channel_connection_ids(channel.root_id())
3948            .filter_map(|(channel_id, role)| {
3949                role.can_see_channel(channel.visibility)
3950                    .then_some(channel_id)
3951            }),
3952        |peer_id| {
3953            peer.send(
3954                peer_id,
3955                proto::UpdateChannels {
3956                    channel_participants: vec![proto::ChannelParticipants {
3957                        channel_id: channel.id.to_proto(),
3958                        participant_user_ids: participants.clone(),
3959                    }],
3960                    ..Default::default()
3961                },
3962            )
3963        },
3964    );
3965}
3966
3967async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3968    let db = session.db().await;
3969
3970    let contacts = db.get_contacts(user_id).await?;
3971    let busy = db.is_user_busy(user_id).await?;
3972
3973    let pool = session.connection_pool().await;
3974    let updated_contact = contact_for_user(user_id, busy, &pool);
3975    for contact in contacts {
3976        if let db::Contact::Accepted {
3977            user_id: contact_user_id,
3978            ..
3979        } = contact
3980        {
3981            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3982                session
3983                    .peer
3984                    .send(
3985                        contact_conn_id,
3986                        proto::UpdateContacts {
3987                            contacts: vec![updated_contact.clone()],
3988                            remove_contacts: Default::default(),
3989                            incoming_requests: Default::default(),
3990                            remove_incoming_requests: Default::default(),
3991                            outgoing_requests: Default::default(),
3992                            remove_outgoing_requests: Default::default(),
3993                        },
3994                    )
3995                    .trace_err();
3996            }
3997        }
3998    }
3999    Ok(())
4000}
4001
/// Remove `connection_id` from its room and clean up after it.
///
/// Returns `Ok(())` without doing anything if the database reports that the
/// connection was not in a room. Otherwise, in order:
/// - notifies the connections of every project it left;
/// - broadcasts the updated room, and the channel's participants if the room
///   belongs to a channel;
/// - sends `CallCanceled` to users whose pending calls were canceled;
/// - refreshes contact status for the session's user and those users;
/// - removes the session's user from the LiveKit room, and deletes that room
///   when the database reports the room as deleted.
///
/// LiveKit and message-send failures are logged, not returned.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned exactly once inside the `if let` below; the `else` branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `livekit_room` is taken out of `left_room.room` before the
        // room itself is taken, so the `room` broadcast below carries an empty
        // `livekit_room` field — confirm this is intended.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before
    // `update_user_contacts`, which acquires the pool again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4075
4076async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4077    let left_channel_buffers = session
4078        .db()
4079        .await
4080        .leave_channel_buffers(session.connection_id)
4081        .await?;
4082
4083    for left_buffer in left_channel_buffers {
4084        channel_buffer_updated(
4085            session.connection_id,
4086            left_buffer.connections,
4087            &proto::UpdateChannelBufferCollaborators {
4088                channel_id: left_buffer.channel_id.to_proto(),
4089                collaborators: left_buffer.collaborators,
4090            },
4091            &session.peer,
4092        );
4093    }
4094
4095    Ok(())
4096}
4097
4098fn project_left(project: &db::LeftProject, session: &Session) {
4099    for connection_id in &project.connection_ids {
4100        if project.should_unshare {
4101            session
4102                .peer
4103                .send(
4104                    *connection_id,
4105                    proto::UnshareProject {
4106                        project_id: project.id.to_proto(),
4107                    },
4108                )
4109                .trace_err();
4110        } else {
4111            session
4112                .peer
4113                .send(
4114                    *connection_id,
4115                    proto::RemoveProjectCollaborator {
4116                        project_id: project.id.to_proto(),
4117                        peer_id: Some(session.connection_id.into()),
4118                    },
4119                )
4120                .trace_err();
4121        }
4122    }
4123}
4124
4125pub trait ResultExt {
4126    type Ok;
4127
4128    fn trace_err(self) -> Option<Self::Ok>;
4129}
4130
4131impl<T, E> ResultExt for Result<T, E>
4132where
4133    E: std::fmt::Debug,
4134{
4135    type Ok = T;
4136
4137    #[track_caller]
4138    fn trace_err(self) -> Option<T> {
4139        match self {
4140            Ok(value) => Some(value),
4141            Err(error) => {
4142                tracing::error!("{:?}", error);
4143                None
4144            }
4145        }
4146    }
4147}