rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::{
   5    AppState, Error, Result, auth,
   6    db::{
   7        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
   8        InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
   9        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
  10    },
  11    executor::Executor,
  12};
  13use anyhow::{Context as _, anyhow, bail};
  14use async_tungstenite::tungstenite::{
  15    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  16};
  17use axum::headers::UserAgent;
  18use axum::{
  19    Extension, Router, TypedHeader,
  20    body::Body,
  21    extract::{
  22        ConnectInfo, WebSocketUpgrade,
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30};
  31use collections::{HashMap, HashSet};
  32pub use connection_pool::{ConnectionPool, ZedVersion};
  33use core::fmt::{self, Debug, Formatter};
  34use futures::TryFutureExt as _;
  35use reqwest_client::ReqwestClient;
  36use rpc::proto::split_repository_update;
  37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  38use tracing::Span;
  39use util::paths::PathStyle;
  40
  41use futures::{
  42    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  43    stream::FuturesUnordered,
  44};
  45use prometheus::{IntGauge, register_int_gauge};
  46use rpc::{
  47    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  50        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52};
  53use semantic_version::SemanticVersion;
  54use serde::{Serialize, Serializer};
  55use std::{
  56    any::TypeId,
  57    future::Future,
  58    marker::PhantomData,
  59    mem,
  60    net::SocketAddr,
  61    ops::{Deref, DerefMut},
  62    rc::Rc,
  63    sync::{
  64        Arc, OnceLock,
  65        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  66    },
  67    time::{Duration, Instant},
  68};
  69use tokio::sync::{Semaphore, watch};
  70use tower::ServiceBuilder;
  71use tracing::{
  72    Instrument,
  73    field::{self},
  74    info_span, instrument,
  75};
  76
/// NOTE(review): not referenced in this part of the file; presumably how long a
/// disconnected client may take to reconnect before its state is released — confirm.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before the background cleanup spawned by `Server::start` runs.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; once they are gone,
/// their old resources can be cleaned up, so this waits somewhat longer than that.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// NOTE(review): not referenced in this part of the file; presumably the page size
/// used when fetching notifications — confirm.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Upper bound on simultaneously held `ConnectionGuard`s.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Number of `ConnectionGuard`s currently alive (incremented on acquire,
/// decremented on drop).
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Field names recorded on tracing spans; all values are in milliseconds.
/// Time from message receipt until its handler finished.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
/// Time spent running the handler itself.
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
/// Time between receipt and the start of handling (total minus processing).
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
/// Time spent waiting for a request forwarded to another connection to resolve.
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased message handler: takes the boxed envelope, the sender's `Session`,
/// and the message's span, and returns the future that handles the message.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
  94
  95pub struct ConnectionGuard;
  96
  97impl ConnectionGuard {
  98    pub fn try_acquire() -> Result<Self, ()> {
  99        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
 100        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 101            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 102            tracing::error!(
 103                "too many concurrent connections: {}",
 104                current_connections + 1
 105            );
 106            return Err(());
 107        }
 108        Ok(ConnectionGuard)
 109    }
 110}
 111
 112impl Drop for ConnectionGuard {
 113    fn drop(&mut self) {
 114        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 115    }
 116}
 117
 118struct Response<R> {
 119    peer: Arc<Peer>,
 120    receipt: Receipt<R>,
 121    responded: Arc<AtomicBool>,
 122}
 123
 124impl<R: RequestMessage> Response<R> {
 125    fn send(self, payload: R::Response) -> Result<()> {
 126        self.responded.store(true, SeqCst);
 127        self.peer.respond(self.receipt, payload)?;
 128        Ok(())
 129    }
 130}
 131
 132#[derive(Clone, Debug)]
 133pub enum Principal {
 134    User(User),
 135    Impersonated { user: User, admin: User },
 136}
 137
 138impl Principal {
 139    fn update_span(&self, span: &tracing::Span) {
 140        match &self {
 141            Principal::User(user) => {
 142                span.record("user_id", user.id.0);
 143                span.record("login", &user.github_login);
 144            }
 145            Principal::Impersonated { user, admin } => {
 146                span.record("user_id", user.id.0);
 147                span.record("login", &user.github_login);
 148                span.record("impersonator", &admin.github_login);
 149            }
 150        }
 151    }
 152}
 153
 154#[derive(Clone)]
 155struct MessageContext {
 156    session: Session,
 157    span: tracing::Span,
 158}
 159
 160impl Deref for MessageContext {
 161    type Target = Session;
 162
 163    fn deref(&self) -> &Self::Target {
 164        &self.session
 165    }
 166}
 167
 168impl MessageContext {
 169    pub fn forward_request<T: RequestMessage>(
 170        &self,
 171        receiver_id: ConnectionId,
 172        request: T,
 173    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 174        let request_start_time = Instant::now();
 175        let span = self.span.clone();
 176        tracing::info!("start forwarding request");
 177        self.peer
 178            .forward_request(self.connection_id, receiver_id, request)
 179            .inspect(move |_| {
 180                span.record(
 181                    HOST_WAITING_MS,
 182                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 183                );
 184            })
 185            .inspect_err(|_| tracing::error!("error forwarding request"))
 186            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 187    }
 188}
 189
 190#[derive(Clone)]
 191struct Session {
 192    principal: Principal,
 193    connection_id: ConnectionId,
 194    db: Arc<tokio::sync::Mutex<DbHandle>>,
 195    peer: Arc<Peer>,
 196    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 197    app_state: Arc<AppState>,
 198    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 199    /// The GeoIP country code for the user.
 200    #[allow(unused)]
 201    geoip_country_code: Option<String>,
 202    #[allow(unused)]
 203    system_id: Option<String>,
 204    _executor: Executor,
 205}
 206
 207impl Session {
 208    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 209        #[cfg(test)]
 210        tokio::task::yield_now().await;
 211        let guard = self.db.lock().await;
 212        #[cfg(test)]
 213        tokio::task::yield_now().await;
 214        guard
 215    }
 216
 217    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 218        #[cfg(test)]
 219        tokio::task::yield_now().await;
 220        let guard = self.connection_pool.lock();
 221        ConnectionPoolGuard {
 222            guard,
 223            _not_send: PhantomData,
 224        }
 225    }
 226
 227    const fn is_staff(&self) -> bool {
 228        match &self.principal {
 229            Principal::User(user) => user.admin,
 230            Principal::Impersonated { .. } => true,
 231        }
 232    }
 233
 234    const fn user_id(&self) -> UserId {
 235        match &self.principal {
 236            Principal::User(user) => user.id,
 237            Principal::Impersonated { user, .. } => user.id,
 238        }
 239    }
 240
 241    pub fn email(&self) -> Option<String> {
 242        match &self.principal {
 243            Principal::User(user) => user.email_address.clone(),
 244            Principal::Impersonated { user, .. } => user.email_address.clone(),
 245        }
 246    }
 247}
 248
 249impl Debug for Session {
 250    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 251        let mut result = f.debug_struct("Session");
 252        match &self.principal {
 253            Principal::User(user) => {
 254                result.field("user", &user.github_login);
 255            }
 256            Principal::Impersonated { user, admin } => {
 257                result.field("user", &user.github_login);
 258                result.field("impersonator", &admin.github_login);
 259            }
 260        }
 261        result.field("connection_id", &self.connection_id).finish()
 262    }
 263}
 264
 265struct DbHandle(Arc<Database>);
 266
 267impl Deref for DbHandle {
 268    type Target = Database;
 269
 270    fn deref(&self) -> &Self::Target {
 271        self.0.as_ref()
 272    }
 273}
 274
/// The RPC server: owns the peer, the connection pool, and the registered
/// message handlers.
pub struct Server {
    /// Current server id. It sits behind a mutex because the test-only `reset`
    /// replaces it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown` and back to `false` by the test-only `reset`.
    teardown: watch::Sender<bool>,
}
 283
 284pub(crate) struct ConnectionPoolGuard<'a> {
 285    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 286    _not_send: PhantomData<Rc<()>>,
 287}
 288
 289#[derive(Serialize)]
 290pub struct ServerSnapshot<'a> {
 291    peer: &'a Peer,
 292    #[serde(serialize_with = "serialize_deref")]
 293    connection_pool: ConnectionPoolGuard<'a>,
 294}
 295
 296pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 297where
 298    S: Serializer,
 299    T: Deref<Target = U>,
 300    U: Serialize,
 301{
 302    Serialize::serialize(value.deref(), serializer)
 303}
 304
 305impl Server {
 306    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 307        let mut server = Self {
 308            id: parking_lot::Mutex::new(id),
 309            peer: Peer::new(id.0 as u32),
 310            app_state,
 311            connection_pool: Default::default(),
 312            handlers: Default::default(),
 313            teardown: watch::channel(false).0,
 314        };
 315
 316        server
 317            .add_request_handler(ping)
 318            .add_request_handler(create_room)
 319            .add_request_handler(join_room)
 320            .add_request_handler(rejoin_room)
 321            .add_request_handler(leave_room)
 322            .add_request_handler(set_room_participant_role)
 323            .add_request_handler(call)
 324            .add_request_handler(cancel_call)
 325            .add_message_handler(decline_call)
 326            .add_request_handler(update_participant_location)
 327            .add_request_handler(share_project)
 328            .add_message_handler(unshare_project)
 329            .add_request_handler(join_project)
 330            .add_message_handler(leave_project)
 331            .add_request_handler(update_project)
 332            .add_request_handler(update_worktree)
 333            .add_request_handler(update_repository)
 334            .add_request_handler(remove_repository)
 335            .add_message_handler(start_language_server)
 336            .add_message_handler(update_language_server)
 337            .add_message_handler(update_diagnostic_summary)
 338            .add_message_handler(update_worktree_settings)
 339            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 340            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 341            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 342            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 343            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 344            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 345            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 346            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 347            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 348            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 349            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 350            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 351            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 352            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 353            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 354            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 355            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 356            .add_request_handler(
 357                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 358            )
 359            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 360            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 361            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 362            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 363            .add_request_handler(
 364                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 365            )
 366            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 367            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 368            .add_request_handler(
 369                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 370            )
 371            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 372            .add_request_handler(
 373                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 374            )
 375            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 376            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 377            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 378            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 379            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 380            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 381            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 382            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 383            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 384            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 385            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 386            .add_request_handler(
 387                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 388            )
 389            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 390            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 391            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 392            .add_request_handler(lsp_query)
 393            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
 394            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 395            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 396            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 397            .add_message_handler(create_buffer_for_peer)
 398            .add_request_handler(update_buffer)
 399            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 400            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 401            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 402            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 403            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 404            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 405            .add_message_handler(
 406                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 407            )
 408            .add_request_handler(get_users)
 409            .add_request_handler(fuzzy_search_users)
 410            .add_request_handler(request_contact)
 411            .add_request_handler(remove_contact)
 412            .add_request_handler(respond_to_contact_request)
 413            .add_message_handler(subscribe_to_channels)
 414            .add_request_handler(create_channel)
 415            .add_request_handler(delete_channel)
 416            .add_request_handler(invite_channel_member)
 417            .add_request_handler(remove_channel_member)
 418            .add_request_handler(set_channel_member_role)
 419            .add_request_handler(set_channel_visibility)
 420            .add_request_handler(rename_channel)
 421            .add_request_handler(join_channel_buffer)
 422            .add_request_handler(leave_channel_buffer)
 423            .add_message_handler(update_channel_buffer)
 424            .add_request_handler(rejoin_channel_buffers)
 425            .add_request_handler(get_channel_members)
 426            .add_request_handler(respond_to_channel_invite)
 427            .add_request_handler(join_channel)
 428            .add_request_handler(join_channel_chat)
 429            .add_message_handler(leave_channel_chat)
 430            .add_request_handler(send_channel_message)
 431            .add_request_handler(remove_channel_message)
 432            .add_request_handler(update_channel_message)
 433            .add_request_handler(get_channel_messages)
 434            .add_request_handler(get_channel_messages_by_id)
 435            .add_request_handler(get_notifications)
 436            .add_request_handler(mark_notification_as_read)
 437            .add_request_handler(move_channel)
 438            .add_request_handler(reorder_channel)
 439            .add_request_handler(follow)
 440            .add_message_handler(unfollow)
 441            .add_message_handler(update_followers)
 442            .add_message_handler(acknowledge_channel_message)
 443            .add_message_handler(acknowledge_buffer_version)
 444            .add_request_handler(get_supermaven_api_key)
 445            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 446            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 447            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 448            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 449            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 450            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 451            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 452            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
 453            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 454            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 455            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 456            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 457            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 458            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 459            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 460            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 461            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 462            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 463            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 464            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 465            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 466            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 467            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 468            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 469            .add_message_handler(update_context)
 470            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
 471            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>);
 472
 473        Arc::new(server)
 474    }
 475
    /// Schedules background cleanup of resources left behind by stale servers and
    /// returns immediately with `Ok(())`.
    ///
    /// A detached task waits [`CLEANUP_TIMEOUT`] and then, in order:
    /// 1. deletes stale channel chat participants;
    /// 2. fetches the stale room ids and channel-buffer ids for this environment and
    ///    server id;
    /// 3. for each stale channel buffer, clears stale collaborators and sends the
    ///    refreshed collaborator list to the remaining connections;
    /// 4. for each stale room:
    ///    - clears stale participants and broadcasts the room (and channel) update;
    ///    - sends `CallCanceled` to users whose calls were canceled;
    ///    - pushes updated contact info to the accepted contacts of affected users;
    ///    - deletes the LiveKit room if no participants remain;
    /// 5. deletes stale channel chat participants again, clears old worktree entries,
    ///    and deletes stale servers.
    ///
    /// Each fallible step passes through `trace_err()`, which turns a failure into
    /// `None` (presumably after logging it). A single failure therefore skips only
    /// the work that depends on it, and the remaining steps still run.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        // Clone everything the task needs so the detached future does not borrow `self`.
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            // Tell every connection still in the buffer who the remaining
                            // collaborators are.
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used when clearing the room fails: nothing to notify
                        // and no LiveKit room to delete.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            // Move these out of the refreshed room instead of cloning them.
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the next `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Do the database reads before taking the pool lock; the lock
                            // is then held only for the synchronous sends below.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts receive the updated entry.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room when a LiveKit client is configured
                        // and the refreshed room has no participants left.
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                // NOTE(review): this repeats the `delete_stale_channel_chat_participants`
                // call made right after the timeout. Confirm whether the second pass is
                // intentional (e.g. to catch rows written meanwhile) or can be removed.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 646
 647    pub fn teardown(&self) {
 648        self.peer.teardown();
 649        self.connection_pool.lock().reset();
 650        let _ = self.teardown.send(true);
 651    }
 652
 653    #[cfg(test)]
 654    pub fn reset(&self, id: ServerId) {
 655        self.teardown();
 656        *self.id.lock() = id;
 657        self.peer.reset(id.0 as u32);
 658        let _ = self.teardown.send(false);
 659    }
 660
 661    #[cfg(test)]
 662    pub fn id(&self) -> ServerId {
 663        *self.id.lock()
 664    }
 665
 666    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 667    where
 668        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
 669        Fut: 'static + Send + Future<Output = Result<()>>,
 670        M: EnvelopedMessage,
 671    {
 672        let prev_handler = self.handlers.insert(
 673            TypeId::of::<M>(),
 674            Box::new(move |envelope, session, span| {
 675                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 676                let received_at = envelope.received_at;
 677                tracing::info!("message received");
 678                let start_time = Instant::now();
 679                let future = (handler)(
 680                    *envelope,
 681                    MessageContext {
 682                        session,
 683                        span: span.clone(),
 684                    },
 685                );
 686                async move {
 687                    let result = future.await;
 688                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 689                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 690                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 691                    span.record(TOTAL_DURATION_MS, total_duration_ms);
 692                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
 693                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
 694                    match result {
 695                        Err(error) => {
 696                            tracing::error!(?error, "error handling message")
 697                        }
 698                        Ok(()) => tracing::info!("finished handling message"),
 699                    }
 700                }
 701                .boxed()
 702            }),
 703        );
 704        if prev_handler.is_some() {
 705            panic!("registered a handler for the same message twice");
 706        }
 707        self
 708    }
 709
 710    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 711    where
 712        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 713        Fut: 'static + Send + Future<Output = Result<()>>,
 714        M: EnvelopedMessage,
 715    {
 716        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 717        self
 718    }
 719
 720    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 721    where
 722        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 723        Fut: Send + Future<Output = Result<()>>,
 724        M: RequestMessage,
 725    {
 726        let handler = Arc::new(handler);
 727        self.add_handler(move |envelope, session| {
 728            let receipt = envelope.receipt();
 729            let handler = handler.clone();
 730            async move {
 731                let peer = session.peer.clone();
 732                let responded = Arc::new(AtomicBool::default());
 733                let response = Response {
 734                    peer: peer.clone(),
 735                    responded: responded.clone(),
 736                    receipt,
 737                };
 738                match (handler)(envelope.payload, response, session).await {
 739                    Ok(()) => {
 740                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 741                            Ok(())
 742                        } else {
 743                            Err(anyhow!("handler did not send a response"))?
 744                        }
 745                    }
 746                    Err(error) => {
 747                        let proto_err = match &error {
 748                            Error::Internal(err) => err.to_proto(),
 749                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 750                        };
 751                        peer.respond_with_error(receipt, proto_err)?;
 752                        Err(error)
 753                    }
 754                }
 755            }
 756        })
 757    }
 758
    /// Builds the future that serves one client connection for its whole lifetime.
    ///
    /// Metadata known up front (address, principal, user agent, release channel, GeoIP
    /// country code) is recorded on a "handle connection" span that instruments the
    /// returned future. When polled, the future:
    ///
    /// 1. returns immediately if teardown has already been signalled;
    /// 2. registers `connection` with the peer and builds the per-connection [`Session`];
    /// 3. runs `send_initial_client_update` (forwarding `send_connection_id`), then drops
    ///    `connection_guard`;
    /// 4. dispatches incoming messages to the registered handlers until the connection
    ///    closes, the I/O future completes, or teardown is signalled;
    /// 5. when the connection closes or the I/O future completes, runs `connection_lost`.
    ///
    /// Teardown during the message loop returns without calling `connection_lost`, as do
    /// the early returns (teardown before start, HTTP client creation failure, failed
    /// initial update).
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribed eagerly, outside the returned future; checked on entry and
        // watched inside the message loop.
        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse to serve a new connection once teardown has been signalled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // A Supermaven admin client is only built when an admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The guard is only held until the initial client update has been sent.
            drop(connection_guard);

            // Fused so `select_biased!` can poll it again on every loop iteration.
            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when, e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stopped processing further
            // messages until their respective request completed, neither would get a chance to
            // respond to the other client's request, causing a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages that arrived later in the spirit of making progress.
            //
            // Each in-flight handler (foreground or background) holds one semaphore permit.
            // The permit is acquired *before* reading the next message, so once
            // MAX_CONCURRENT_HANDLERS handlers are running, no further messages are read
            // until one of them finishes.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            // Number of permits currently handed out, i.e. handlers in flight.
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown is checked first, then I/O completion, then progress on
                // in-flight foreground handlers, and only then the next incoming message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                // Leave the span before building the wrapper future; the span
                                // is re-entered on every poll via `.instrument(span)`.
                                drop(span_enter);

                                // The permit is released only after the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor; foreground
                                // ones are driven by this loop via `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that have not finished.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
 950
 951    async fn send_initial_client_update(
 952        &self,
 953        connection_id: ConnectionId,
 954        zed_version: ZedVersion,
 955        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 956        session: &Session,
 957    ) -> Result<()> {
 958        self.peer.send(
 959            connection_id,
 960            proto::Hello {
 961                peer_id: Some(connection_id.into()),
 962            },
 963        )?;
 964        tracing::info!("sent hello message");
 965        if let Some(send_connection_id) = send_connection_id.take() {
 966            let _ = send_connection_id.send(connection_id);
 967        }
 968
 969        match &session.principal {
 970            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 971                if !user.connected_once {
 972                    self.peer.send(connection_id, proto::ShowContacts {})?;
 973                    self.app_state
 974                        .db
 975                        .set_user_connected_once(user.id, true)
 976                        .await?;
 977                }
 978
 979                let contacts = self.app_state.db.get_contacts(user.id).await?;
 980
 981                {
 982                    let mut pool = self.connection_pool.lock();
 983                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 984                    self.peer.send(
 985                        connection_id,
 986                        build_initial_contacts_update(contacts, &pool),
 987                    )?;
 988                }
 989
 990                if should_auto_subscribe_to_channels(zed_version) {
 991                    subscribe_user_to_channels(user.id, session).await?;
 992                }
 993
 994                if let Some(incoming_call) =
 995                    self.app_state.db.incoming_call_for_user(user.id).await?
 996                {
 997                    self.peer.send(connection_id, incoming_call)?;
 998                }
 999
1000                update_user_contacts(user.id, session).await?;
1001            }
1002        }
1003
1004        Ok(())
1005    }
1006
1007    pub async fn invite_code_redeemed(
1008        self: &Arc<Self>,
1009        inviter_id: UserId,
1010        invitee_id: UserId,
1011    ) -> Result<()> {
1012        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await?
1013            && let Some(code) = &user.invite_code
1014        {
1015            let pool = self.connection_pool.lock();
1016            let invitee_contact = contact_for_user(invitee_id, false, &pool);
1017            for connection_id in pool.user_connection_ids(inviter_id) {
1018                self.peer.send(
1019                    connection_id,
1020                    proto::UpdateContacts {
1021                        contacts: vec![invitee_contact.clone()],
1022                        ..Default::default()
1023                    },
1024                )?;
1025                self.peer.send(
1026                    connection_id,
1027                    proto::UpdateInviteInfo {
1028                        url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1029                        count: user.invite_count as u32,
1030                    },
1031                )?;
1032            }
1033        }
1034        Ok(())
1035    }
1036
1037    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1038        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await?
1039            && let Some(invite_code) = &user.invite_code
1040        {
1041            let pool = self.connection_pool.lock();
1042            for connection_id in pool.user_connection_ids(user_id) {
1043                self.peer.send(
1044                    connection_id,
1045                    proto::UpdateInviteInfo {
1046                        url: format!(
1047                            "{}{}",
1048                            self.app_state.config.invite_link_prefix, invite_code
1049                        ),
1050                        count: user.invite_count as u32,
1051                    },
1052                )?;
1053            }
1054        }
1055        Ok(())
1056    }
1057
1058    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1059        ServerSnapshot {
1060            connection_pool: ConnectionPoolGuard {
1061                guard: self.connection_pool.lock(),
1062                _not_send: PhantomData,
1063            },
1064            peer: &self.peer,
1065        }
1066    }
1067}
1068
1069impl Deref for ConnectionPoolGuard<'_> {
1070    type Target = ConnectionPool;
1071
1072    fn deref(&self) -> &Self::Target {
1073        &self.guard
1074    }
1075}
1076
1077impl DerefMut for ConnectionPoolGuard<'_> {
1078    fn deref_mut(&mut self) -> &mut Self::Target {
1079        &mut self.guard
1080    }
1081}
1082
1083impl Drop for ConnectionPoolGuard<'_> {
1084    fn drop(&mut self) {
1085        #[cfg(test)]
1086        self.check_invariants();
1087    }
1088}
1089
1090fn broadcast<F>(
1091    sender_id: Option<ConnectionId>,
1092    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1093    mut f: F,
1094) where
1095    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1096{
1097    for receiver_id in receiver_ids {
1098        if Some(receiver_id) != sender_id
1099            && let Err(error) = f(receiver_id)
1100        {
1101            tracing::error!("failed to send to {:?} {}", receiver_id, error);
1102        }
1103    }
1104}
1105
1106pub struct ProtocolVersion(u32);
1107
1108impl Header for ProtocolVersion {
1109    fn name() -> &'static HeaderName {
1110        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1111        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1112    }
1113
1114    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1115    where
1116        Self: Sized,
1117        I: Iterator<Item = &'i axum::http::HeaderValue>,
1118    {
1119        let version = values
1120            .next()
1121            .ok_or_else(axum::headers::Error::invalid)?
1122            .to_str()
1123            .map_err(|_| axum::headers::Error::invalid())?
1124            .parse()
1125            .map_err(|_| axum::headers::Error::invalid())?;
1126        Ok(Self(version))
1127    }
1128
1129    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1130        values.extend([self.0.to_string().parse().unwrap()]);
1131    }
1132}
1133
1134pub struct AppVersionHeader(SemanticVersion);
1135impl Header for AppVersionHeader {
1136    fn name() -> &'static HeaderName {
1137        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1138        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1139    }
1140
1141    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1142    where
1143        Self: Sized,
1144        I: Iterator<Item = &'i axum::http::HeaderValue>,
1145    {
1146        let version = values
1147            .next()
1148            .ok_or_else(axum::headers::Error::invalid)?
1149            .to_str()
1150            .map_err(|_| axum::headers::Error::invalid())?
1151            .parse()
1152            .map_err(|_| axum::headers::Error::invalid())?;
1153        Ok(Self(version))
1154    }
1155
1156    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1157        values.extend([self.0.to_string().parse().unwrap()]);
1158    }
1159}
1160
1161#[derive(Debug)]
1162pub struct ReleaseChannelHeader(String);
1163
1164impl Header for ReleaseChannelHeader {
1165    fn name() -> &'static HeaderName {
1166        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1167        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1168    }
1169
1170    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1171    where
1172        Self: Sized,
1173        I: Iterator<Item = &'i axum::http::HeaderValue>,
1174    {
1175        Ok(Self(
1176            values
1177                .next()
1178                .ok_or_else(axum::headers::Error::invalid)?
1179                .to_str()
1180                .map_err(|_| axum::headers::Error::invalid())?
1181                .to_owned(),
1182        ))
1183    }
1184
1185    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1186        values.extend([self.0.parse().unwrap()]);
1187    }
1188}
1189
1190pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1191    Router::new()
1192        .route("/rpc", get(handle_websocket_request))
1193        .layer(
1194            ServiceBuilder::new()
1195                .layer(Extension(server.app_state.clone()))
1196                .layer(middleware::from_fn(auth::validate_header)),
1197        )
1198        .route("/metrics", get(handle_metrics))
1199        .layer(Extension(server))
1200}
1201
1202pub async fn handle_websocket_request(
1203    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1204    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1205    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1206    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1207    Extension(server): Extension<Arc<Server>>,
1208    Extension(principal): Extension<Principal>,
1209    user_agent: Option<TypedHeader<UserAgent>>,
1210    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1211    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1212    ws: WebSocketUpgrade,
1213) -> axum::response::Response {
1214    if protocol_version != rpc::PROTOCOL_VERSION {
1215        return (
1216            StatusCode::UPGRADE_REQUIRED,
1217            "client must be upgraded".to_string(),
1218        )
1219            .into_response();
1220    }
1221
1222    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1223        return (
1224            StatusCode::UPGRADE_REQUIRED,
1225            "no version header found".to_string(),
1226        )
1227            .into_response();
1228    };
1229
1230    let release_channel = release_channel_header.map(|header| header.0.0);
1231
1232    if !version.can_collaborate() {
1233        return (
1234            StatusCode::UPGRADE_REQUIRED,
1235            "client must be upgraded".to_string(),
1236        )
1237            .into_response();
1238    }
1239
1240    let socket_address = socket_address.to_string();
1241
1242    // Acquire connection guard before WebSocket upgrade
1243    let connection_guard = match ConnectionGuard::try_acquire() {
1244        Ok(guard) => guard,
1245        Err(()) => {
1246            return (
1247                StatusCode::SERVICE_UNAVAILABLE,
1248                "Too many concurrent connections",
1249            )
1250                .into_response();
1251        }
1252    };
1253
1254    ws.on_upgrade(move |socket| {
1255        let socket = socket
1256            .map_ok(to_tungstenite_message)
1257            .err_into()
1258            .with(|message| async move { to_axum_message(message) });
1259        let connection = Connection::new(Box::pin(socket));
1260        async move {
1261            server
1262                .handle_connection(
1263                    connection,
1264                    socket_address,
1265                    principal,
1266                    version,
1267                    release_channel,
1268                    user_agent.map(|header| header.to_string()),
1269                    country_code_header.map(|header| header.to_string()),
1270                    system_id_header.map(|header| header.to_string()),
1271                    None,
1272                    Executor::Production,
1273                    Some(connection_guard),
1274                )
1275                .await;
1276        }
1277    })
1278}
1279
1280pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1281    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1282    let connections_metric = CONNECTIONS_METRIC
1283        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1284
1285    let connections = server
1286        .connection_pool
1287        .lock()
1288        .connections()
1289        .filter(|connection| !connection.admin)
1290        .count();
1291    connections_metric.set(connections as _);
1292
1293    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1294    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1295        register_int_gauge!(
1296            "shared_projects",
1297            "number of open projects with one or more guests"
1298        )
1299        .unwrap()
1300    });
1301
1302    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1303    shared_projects_metric.set(shared_projects as _);
1304
1305    let encoder = prometheus::TextEncoder::new();
1306    let metric_families = prometheus::gather();
1307    let encoded_metrics = encoder
1308        .encode_to_string(&metric_families)
1309        .map_err(|err| anyhow!("{err}"))?;
1310    Ok(encoded_metrics)
1311}
1312
/// Cleans up after a client connection closes.
///
/// Immediate steps:
/// - disconnects the peer;
/// - removes the connection from the pool (propagating any error);
/// - records the loss in the database (errors are only logged).
///
/// It then races `RECONNECT_TIMEOUT` against server teardown. Presumably the delay
/// gives the client a window to reconnect; confirm against `RECONNECT_TIMEOUT`'s
/// definition. If the timeout wins:
/// - the session leaves its room and channel buffers;
/// - any pending call is declined when the user has no other online connection;
/// - the user's contacts are updated.
///
/// If teardown wins, that deferred cleanup is skipped.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            // Timeout elapsed: release everything the session still holds. Failures of the
            // individual cleanup steps are logged via `trace_err` and do not abort the rest.

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call if no other connection of this user is online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        // Server teardown: skip the deferred cleanup entirely.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1359
1360/// Acknowledges a ping from a client, used to keep the connection alive.
1361async fn ping(
1362    _: proto::Ping,
1363    response: Response<proto::Ping>,
1364    _session: MessageContext,
1365) -> Result<()> {
1366    response.send(proto::Ack {})?;
1367    Ok(())
1368}
1369
1370/// Creates a new room for calling (outside of channels)
1371async fn create_room(
1372    _request: proto::CreateRoom,
1373    response: Response<proto::CreateRoom>,
1374    session: MessageContext,
1375) -> Result<()> {
1376    let livekit_room = nanoid::nanoid!(30);
1377
1378    let live_kit_connection_info = util::maybe!(async {
1379        let live_kit = session.app_state.livekit_client.as_ref();
1380        let live_kit = live_kit?;
1381        let user_id = session.user_id().to_string();
1382
1383        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1384
1385        Some(proto::LiveKitConnectionInfo {
1386            server_url: live_kit.url().into(),
1387            token,
1388            can_publish: true,
1389        })
1390    })
1391    .await;
1392
1393    let room = session
1394        .db()
1395        .await
1396        .create_room(session.user_id(), session.connection_id, &livekit_room)
1397        .await?;
1398
1399    response.send(proto::CreateRoomResponse {
1400        room: Some(room.clone()),
1401        live_kit_connection_info,
1402    })?;
1403
1404    update_user_contacts(session.user_id(), &session).await?;
1405    Ok(())
1406}
1407
1408/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1409async fn join_room(
1410    request: proto::JoinRoom,
1411    response: Response<proto::JoinRoom>,
1412    session: MessageContext,
1413) -> Result<()> {
1414    let room_id = RoomId::from_proto(request.id);
1415
1416    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1417
1418    if let Some(channel_id) = channel_id {
1419        return join_channel_internal(channel_id, Box::new(response), session).await;
1420    }
1421
1422    let joined_room = {
1423        let room = session
1424            .db()
1425            .await
1426            .join_room(room_id, session.user_id(), session.connection_id)
1427            .await?;
1428        room_updated(&room.room, &session.peer);
1429        room.into_inner()
1430    };
1431
1432    for connection_id in session
1433        .connection_pool()
1434        .await
1435        .user_connection_ids(session.user_id())
1436    {
1437        session
1438            .peer
1439            .send(
1440                connection_id,
1441                proto::CallCanceled {
1442                    room_id: room_id.to_proto(),
1443                },
1444            )
1445            .trace_err();
1446    }
1447
1448    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1449    {
1450        live_kit
1451            .room_token(
1452                &joined_room.room.livekit_room,
1453                &session.user_id().to_string(),
1454            )
1455            .trace_err()
1456            .map(|token| proto::LiveKitConnectionInfo {
1457                server_url: live_kit.url().into(),
1458                token,
1459                can_publish: true,
1460            })
1461    } else {
1462        None
1463    };
1464
1465    response.send(proto::JoinRoomResponse {
1466        room: Some(joined_room.room),
1467        channel_id: None,
1468        live_kit_connection_info,
1469    })?;
1470
1471    update_user_contacts(session.user_id(), &session).await?;
1472    Ok(())
1473}
1474
1475/// Rejoin room is used to reconnect to a room after connection errors.
1476async fn rejoin_room(
1477    request: proto::RejoinRoom,
1478    response: Response<proto::RejoinRoom>,
1479    session: MessageContext,
1480) -> Result<()> {
1481    let room;
1482    let channel;
1483    {
1484        let mut rejoined_room = session
1485            .db()
1486            .await
1487            .rejoin_room(request, session.user_id(), session.connection_id)
1488            .await?;
1489
1490        response.send(proto::RejoinRoomResponse {
1491            room: Some(rejoined_room.room.clone()),
1492            reshared_projects: rejoined_room
1493                .reshared_projects
1494                .iter()
1495                .map(|project| proto::ResharedProject {
1496                    id: project.id.to_proto(),
1497                    collaborators: project
1498                        .collaborators
1499                        .iter()
1500                        .map(|collaborator| collaborator.to_proto())
1501                        .collect(),
1502                })
1503                .collect(),
1504            rejoined_projects: rejoined_room
1505                .rejoined_projects
1506                .iter()
1507                .map(|rejoined_project| rejoined_project.to_proto())
1508                .collect(),
1509        })?;
1510        room_updated(&rejoined_room.room, &session.peer);
1511
1512        for project in &rejoined_room.reshared_projects {
1513            for collaborator in &project.collaborators {
1514                session
1515                    .peer
1516                    .send(
1517                        collaborator.connection_id,
1518                        proto::UpdateProjectCollaborator {
1519                            project_id: project.id.to_proto(),
1520                            old_peer_id: Some(project.old_connection_id.into()),
1521                            new_peer_id: Some(session.connection_id.into()),
1522                        },
1523                    )
1524                    .trace_err();
1525            }
1526
1527            broadcast(
1528                Some(session.connection_id),
1529                project
1530                    .collaborators
1531                    .iter()
1532                    .map(|collaborator| collaborator.connection_id),
1533                |connection_id| {
1534                    session.peer.forward_send(
1535                        session.connection_id,
1536                        connection_id,
1537                        proto::UpdateProject {
1538                            project_id: project.id.to_proto(),
1539                            worktrees: project.worktrees.clone(),
1540                        },
1541                    )
1542                },
1543            );
1544        }
1545
1546        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1547
1548        let rejoined_room = rejoined_room.into_inner();
1549
1550        room = rejoined_room.room;
1551        channel = rejoined_room.channel;
1552    }
1553
1554    if let Some(channel) = channel {
1555        channel_updated(
1556            &channel,
1557            &room,
1558            &session.peer,
1559            &*session.connection_pool().await,
1560        );
1561    }
1562
1563    update_user_contacts(session.user_id(), &session).await?;
1564    Ok(())
1565}
1566
1567fn notify_rejoined_projects(
1568    rejoined_projects: &mut Vec<RejoinedProject>,
1569    session: &Session,
1570) -> Result<()> {
1571    for project in rejoined_projects.iter() {
1572        for collaborator in &project.collaborators {
1573            session
1574                .peer
1575                .send(
1576                    collaborator.connection_id,
1577                    proto::UpdateProjectCollaborator {
1578                        project_id: project.id.to_proto(),
1579                        old_peer_id: Some(project.old_connection_id.into()),
1580                        new_peer_id: Some(session.connection_id.into()),
1581                    },
1582                )
1583                .trace_err();
1584        }
1585    }
1586
1587    for project in rejoined_projects {
1588        for worktree in mem::take(&mut project.worktrees) {
1589            // Stream this worktree's entries.
1590            let message = proto::UpdateWorktree {
1591                project_id: project.id.to_proto(),
1592                worktree_id: worktree.id,
1593                abs_path: worktree.abs_path.clone(),
1594                root_name: worktree.root_name,
1595                updated_entries: worktree.updated_entries,
1596                removed_entries: worktree.removed_entries,
1597                scan_id: worktree.scan_id,
1598                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1599                updated_repositories: worktree.updated_repositories,
1600                removed_repositories: worktree.removed_repositories,
1601            };
1602            for update in proto::split_worktree_update(message) {
1603                session.peer.send(session.connection_id, update)?;
1604            }
1605
1606            // Stream this worktree's diagnostics.
1607            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1608            if let Some(summary) = worktree_diagnostics.next() {
1609                let message = proto::UpdateDiagnosticSummary {
1610                    project_id: project.id.to_proto(),
1611                    worktree_id: worktree.id,
1612                    summary: Some(summary),
1613                    more_summaries: worktree_diagnostics.collect(),
1614                };
1615                session.peer.send(session.connection_id, message)?;
1616            }
1617
1618            for settings_file in worktree.settings_files {
1619                session.peer.send(
1620                    session.connection_id,
1621                    proto::UpdateWorktreeSettings {
1622                        project_id: project.id.to_proto(),
1623                        worktree_id: worktree.id,
1624                        path: settings_file.path,
1625                        content: Some(settings_file.content),
1626                        kind: Some(settings_file.kind.to_proto().into()),
1627                    },
1628                )?;
1629            }
1630        }
1631
1632        for repository in mem::take(&mut project.updated_repositories) {
1633            for update in split_repository_update(repository) {
1634                session.peer.send(session.connection_id, update)?;
1635            }
1636        }
1637
1638        for id in mem::take(&mut project.removed_repositories) {
1639            session.peer.send(
1640                session.connection_id,
1641                proto::RemoveRepository {
1642                    project_id: project.id.to_proto(),
1643                    id,
1644                },
1645            )?;
1646        }
1647    }
1648
1649    Ok(())
1650}
1651
1652/// leave room disconnects from the room.
1653async fn leave_room(
1654    _: proto::LeaveRoom,
1655    response: Response<proto::LeaveRoom>,
1656    session: MessageContext,
1657) -> Result<()> {
1658    leave_room_for_session(&session, session.connection_id).await?;
1659    response.send(proto::Ack {})?;
1660    Ok(())
1661}
1662
1663/// Updates the permissions of someone else in the room.
1664async fn set_room_participant_role(
1665    request: proto::SetRoomParticipantRole,
1666    response: Response<proto::SetRoomParticipantRole>,
1667    session: MessageContext,
1668) -> Result<()> {
1669    let user_id = UserId::from_proto(request.user_id);
1670    let role = ChannelRole::from(request.role());
1671
1672    let (livekit_room, can_publish) = {
1673        let room = session
1674            .db()
1675            .await
1676            .set_room_participant_role(
1677                session.user_id(),
1678                RoomId::from_proto(request.room_id),
1679                user_id,
1680                role,
1681            )
1682            .await?;
1683
1684        let livekit_room = room.livekit_room.clone();
1685        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1686        room_updated(&room, &session.peer);
1687        (livekit_room, can_publish)
1688    };
1689
1690    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1691        live_kit
1692            .update_participant(
1693                livekit_room.clone(),
1694                request.user_id.to_string(),
1695                livekit_api::proto::ParticipantPermission {
1696                    can_subscribe: true,
1697                    can_publish,
1698                    can_publish_data: can_publish,
1699                    hidden: false,
1700                    recorder: false,
1701                },
1702            )
1703            .await
1704            .trace_err();
1705    }
1706
1707    response.send(proto::Ack {})?;
1708    Ok(())
1709}
1710
1711/// Call someone else into the current room
1712async fn call(
1713    request: proto::Call,
1714    response: Response<proto::Call>,
1715    session: MessageContext,
1716) -> Result<()> {
1717    let room_id = RoomId::from_proto(request.room_id);
1718    let calling_user_id = session.user_id();
1719    let calling_connection_id = session.connection_id;
1720    let called_user_id = UserId::from_proto(request.called_user_id);
1721    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1722    if !session
1723        .db()
1724        .await
1725        .has_contact(calling_user_id, called_user_id)
1726        .await?
1727    {
1728        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1729    }
1730
1731    let incoming_call = {
1732        let (room, incoming_call) = &mut *session
1733            .db()
1734            .await
1735            .call(
1736                room_id,
1737                calling_user_id,
1738                calling_connection_id,
1739                called_user_id,
1740                initial_project_id,
1741            )
1742            .await?;
1743        room_updated(room, &session.peer);
1744        mem::take(incoming_call)
1745    };
1746    update_user_contacts(called_user_id, &session).await?;
1747
1748    let mut calls = session
1749        .connection_pool()
1750        .await
1751        .user_connection_ids(called_user_id)
1752        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1753        .collect::<FuturesUnordered<_>>();
1754
1755    while let Some(call_response) = calls.next().await {
1756        match call_response.as_ref() {
1757            Ok(_) => {
1758                response.send(proto::Ack {})?;
1759                return Ok(());
1760            }
1761            Err(_) => {
1762                call_response.trace_err();
1763            }
1764        }
1765    }
1766
1767    {
1768        let room = session
1769            .db()
1770            .await
1771            .call_failed(room_id, called_user_id)
1772            .await?;
1773        room_updated(&room, &session.peer);
1774    }
1775    update_user_contacts(called_user_id, &session).await?;
1776
1777    Err(anyhow!("failed to ring user"))?
1778}
1779
1780/// Cancel an outgoing call.
1781async fn cancel_call(
1782    request: proto::CancelCall,
1783    response: Response<proto::CancelCall>,
1784    session: MessageContext,
1785) -> Result<()> {
1786    let called_user_id = UserId::from_proto(request.called_user_id);
1787    let room_id = RoomId::from_proto(request.room_id);
1788    {
1789        let room = session
1790            .db()
1791            .await
1792            .cancel_call(room_id, session.connection_id, called_user_id)
1793            .await?;
1794        room_updated(&room, &session.peer);
1795    }
1796
1797    for connection_id in session
1798        .connection_pool()
1799        .await
1800        .user_connection_ids(called_user_id)
1801    {
1802        session
1803            .peer
1804            .send(
1805                connection_id,
1806                proto::CallCanceled {
1807                    room_id: room_id.to_proto(),
1808                },
1809            )
1810            .trace_err();
1811    }
1812    response.send(proto::Ack {})?;
1813
1814    update_user_contacts(called_user_id, &session).await?;
1815    Ok(())
1816}
1817
1818/// Decline an incoming call.
1819async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1820    let room_id = RoomId::from_proto(message.room_id);
1821    {
1822        let room = session
1823            .db()
1824            .await
1825            .decline_call(Some(room_id), session.user_id())
1826            .await?
1827            .context("declining call")?;
1828        room_updated(&room, &session.peer);
1829    }
1830
1831    for connection_id in session
1832        .connection_pool()
1833        .await
1834        .user_connection_ids(session.user_id())
1835    {
1836        session
1837            .peer
1838            .send(
1839                connection_id,
1840                proto::CallCanceled {
1841                    room_id: room_id.to_proto(),
1842                },
1843            )
1844            .trace_err();
1845    }
1846    update_user_contacts(session.user_id(), &session).await?;
1847    Ok(())
1848}
1849
1850/// Updates other participants in the room with your current location.
1851async fn update_participant_location(
1852    request: proto::UpdateParticipantLocation,
1853    response: Response<proto::UpdateParticipantLocation>,
1854    session: MessageContext,
1855) -> Result<()> {
1856    let room_id = RoomId::from_proto(request.room_id);
1857    let location = request.location.context("invalid location")?;
1858
1859    let db = session.db().await;
1860    let room = db
1861        .update_room_participant_location(room_id, session.connection_id, location)
1862        .await?;
1863
1864    room_updated(&room, &session.peer);
1865    response.send(proto::Ack {})?;
1866    Ok(())
1867}
1868
1869/// Share a project into the room.
1870async fn share_project(
1871    request: proto::ShareProject,
1872    response: Response<proto::ShareProject>,
1873    session: MessageContext,
1874) -> Result<()> {
1875    let (project_id, room) = &*session
1876        .db()
1877        .await
1878        .share_project(
1879            RoomId::from_proto(request.room_id),
1880            session.connection_id,
1881            &request.worktrees,
1882            request.is_ssh_project,
1883            request.windows_paths.unwrap_or(false),
1884        )
1885        .await?;
1886    response.send(proto::ShareProjectResponse {
1887        project_id: project_id.to_proto(),
1888    })?;
1889    room_updated(room, &session.peer);
1890
1891    Ok(())
1892}
1893
1894/// Unshare a project from the room.
1895async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1896    let project_id = ProjectId::from_proto(message.project_id);
1897    unshare_project_internal(project_id, session.connection_id, &session).await
1898}
1899
1900async fn unshare_project_internal(
1901    project_id: ProjectId,
1902    connection_id: ConnectionId,
1903    session: &Session,
1904) -> Result<()> {
1905    let delete = {
1906        let room_guard = session
1907            .db()
1908            .await
1909            .unshare_project(project_id, connection_id)
1910            .await?;
1911
1912        let (delete, room, guest_connection_ids) = &*room_guard;
1913
1914        let message = proto::UnshareProject {
1915            project_id: project_id.to_proto(),
1916        };
1917
1918        broadcast(
1919            Some(connection_id),
1920            guest_connection_ids.iter().copied(),
1921            |conn_id| session.peer.send(conn_id, message.clone()),
1922        );
1923        if let Some(room) = room {
1924            room_updated(room, &session.peer);
1925        }
1926
1927        *delete
1928    };
1929
1930    if delete {
1931        let db = session.db().await;
1932        db.delete_project(project_id).await?;
1933    }
1934
1935    Ok(())
1936}
1937
1938/// Join someone elses shared project.
1939async fn join_project(
1940    request: proto::JoinProject,
1941    response: Response<proto::JoinProject>,
1942    session: MessageContext,
1943) -> Result<()> {
1944    let project_id = ProjectId::from_proto(request.project_id);
1945
1946    tracing::info!(%project_id, "join project");
1947
1948    let db = session.db().await;
1949    let (project, replica_id) = &mut *db
1950        .join_project(
1951            project_id,
1952            session.connection_id,
1953            session.user_id(),
1954            request.committer_name.clone(),
1955            request.committer_email.clone(),
1956        )
1957        .await?;
1958    drop(db);
1959    tracing::info!(%project_id, "join remote project");
1960    let collaborators = project
1961        .collaborators
1962        .iter()
1963        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1964        .map(|collaborator| collaborator.to_proto())
1965        .collect::<Vec<_>>();
1966    let project_id = project.id;
1967    let guest_user_id = session.user_id();
1968
1969    let worktrees = project
1970        .worktrees
1971        .iter()
1972        .map(|(id, worktree)| proto::WorktreeMetadata {
1973            id: *id,
1974            root_name: worktree.root_name.clone(),
1975            visible: worktree.visible,
1976            abs_path: worktree.abs_path.clone(),
1977        })
1978        .collect::<Vec<_>>();
1979
1980    let add_project_collaborator = proto::AddProjectCollaborator {
1981        project_id: project_id.to_proto(),
1982        collaborator: Some(proto::Collaborator {
1983            peer_id: Some(session.connection_id.into()),
1984            replica_id: replica_id.0 as u32,
1985            user_id: guest_user_id.to_proto(),
1986            is_host: false,
1987            committer_name: request.committer_name.clone(),
1988            committer_email: request.committer_email.clone(),
1989        }),
1990    };
1991
1992    for collaborator in &collaborators {
1993        session
1994            .peer
1995            .send(
1996                collaborator.peer_id.unwrap().into(),
1997                add_project_collaborator.clone(),
1998            )
1999            .trace_err();
2000    }
2001
2002    // First, we send the metadata associated with each worktree.
2003    let (language_servers, language_server_capabilities) = project
2004        .language_servers
2005        .clone()
2006        .into_iter()
2007        .map(|server| (server.server, server.capabilities))
2008        .unzip();
2009    response.send(proto::JoinProjectResponse {
2010        project_id: project.id.0 as u64,
2011        worktrees,
2012        replica_id: replica_id.0 as u32,
2013        collaborators,
2014        language_servers,
2015        language_server_capabilities,
2016        role: project.role.into(),
2017        windows_paths: project.path_style == PathStyle::Windows,
2018    })?;
2019
2020    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2021        // Stream this worktree's entries.
2022        let message = proto::UpdateWorktree {
2023            project_id: project_id.to_proto(),
2024            worktree_id,
2025            abs_path: worktree.abs_path.clone(),
2026            root_name: worktree.root_name,
2027            updated_entries: worktree.entries,
2028            removed_entries: Default::default(),
2029            scan_id: worktree.scan_id,
2030            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2031            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2032            removed_repositories: Default::default(),
2033        };
2034        for update in proto::split_worktree_update(message) {
2035            session.peer.send(session.connection_id, update.clone())?;
2036        }
2037
2038        // Stream this worktree's diagnostics.
2039        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
2040        if let Some(summary) = worktree_diagnostics.next() {
2041            let message = proto::UpdateDiagnosticSummary {
2042                project_id: project.id.to_proto(),
2043                worktree_id: worktree.id,
2044                summary: Some(summary),
2045                more_summaries: worktree_diagnostics.collect(),
2046            };
2047            session.peer.send(session.connection_id, message)?;
2048        }
2049
2050        for settings_file in worktree.settings_files {
2051            session.peer.send(
2052                session.connection_id,
2053                proto::UpdateWorktreeSettings {
2054                    project_id: project_id.to_proto(),
2055                    worktree_id: worktree.id,
2056                    path: settings_file.path,
2057                    content: Some(settings_file.content),
2058                    kind: Some(settings_file.kind.to_proto() as i32),
2059                },
2060            )?;
2061        }
2062    }
2063
2064    for repository in mem::take(&mut project.repositories) {
2065        for update in split_repository_update(repository) {
2066            session.peer.send(session.connection_id, update)?;
2067        }
2068    }
2069
2070    for language_server in &project.language_servers {
2071        session.peer.send(
2072            session.connection_id,
2073            proto::UpdateLanguageServer {
2074                project_id: project_id.to_proto(),
2075                server_name: Some(language_server.server.name.clone()),
2076                language_server_id: language_server.server.id,
2077                variant: Some(
2078                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2079                        proto::LspDiskBasedDiagnosticsUpdated {},
2080                    ),
2081                ),
2082            },
2083        )?;
2084    }
2085
2086    Ok(())
2087}
2088
2089/// Leave someone elses shared project.
2090async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2091    let sender_id = session.connection_id;
2092    let project_id = ProjectId::from_proto(request.project_id);
2093    let db = session.db().await;
2094
2095    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2096    tracing::info!(
2097        %project_id,
2098        "leave project"
2099    );
2100
2101    project_left(project, &session);
2102    if let Some(room) = room {
2103        room_updated(room, &session.peer);
2104    }
2105
2106    Ok(())
2107}
2108
2109/// Updates other participants with changes to the project
2110async fn update_project(
2111    request: proto::UpdateProject,
2112    response: Response<proto::UpdateProject>,
2113    session: MessageContext,
2114) -> Result<()> {
2115    let project_id = ProjectId::from_proto(request.project_id);
2116    let (room, guest_connection_ids) = &*session
2117        .db()
2118        .await
2119        .update_project(project_id, session.connection_id, &request.worktrees)
2120        .await?;
2121    broadcast(
2122        Some(session.connection_id),
2123        guest_connection_ids.iter().copied(),
2124        |connection_id| {
2125            session
2126                .peer
2127                .forward_send(session.connection_id, connection_id, request.clone())
2128        },
2129    );
2130    if let Some(room) = room {
2131        room_updated(room, &session.peer);
2132    }
2133    response.send(proto::Ack {})?;
2134
2135    Ok(())
2136}
2137
2138/// Updates other participants with changes to the worktree
2139async fn update_worktree(
2140    request: proto::UpdateWorktree,
2141    response: Response<proto::UpdateWorktree>,
2142    session: MessageContext,
2143) -> Result<()> {
2144    let guest_connection_ids = session
2145        .db()
2146        .await
2147        .update_worktree(&request, session.connection_id)
2148        .await?;
2149
2150    broadcast(
2151        Some(session.connection_id),
2152        guest_connection_ids.iter().copied(),
2153        |connection_id| {
2154            session
2155                .peer
2156                .forward_send(session.connection_id, connection_id, request.clone())
2157        },
2158    );
2159    response.send(proto::Ack {})?;
2160    Ok(())
2161}
2162
2163async fn update_repository(
2164    request: proto::UpdateRepository,
2165    response: Response<proto::UpdateRepository>,
2166    session: MessageContext,
2167) -> Result<()> {
2168    let guest_connection_ids = session
2169        .db()
2170        .await
2171        .update_repository(&request, session.connection_id)
2172        .await?;
2173
2174    broadcast(
2175        Some(session.connection_id),
2176        guest_connection_ids.iter().copied(),
2177        |connection_id| {
2178            session
2179                .peer
2180                .forward_send(session.connection_id, connection_id, request.clone())
2181        },
2182    );
2183    response.send(proto::Ack {})?;
2184    Ok(())
2185}
2186
2187async fn remove_repository(
2188    request: proto::RemoveRepository,
2189    response: Response<proto::RemoveRepository>,
2190    session: MessageContext,
2191) -> Result<()> {
2192    let guest_connection_ids = session
2193        .db()
2194        .await
2195        .remove_repository(&request, session.connection_id)
2196        .await?;
2197
2198    broadcast(
2199        Some(session.connection_id),
2200        guest_connection_ids.iter().copied(),
2201        |connection_id| {
2202            session
2203                .peer
2204                .forward_send(session.connection_id, connection_id, request.clone())
2205        },
2206    );
2207    response.send(proto::Ack {})?;
2208    Ok(())
2209}
2210
2211/// Updates other participants with changes to the diagnostics
2212async fn update_diagnostic_summary(
2213    message: proto::UpdateDiagnosticSummary,
2214    session: MessageContext,
2215) -> Result<()> {
2216    let guest_connection_ids = session
2217        .db()
2218        .await
2219        .update_diagnostic_summary(&message, session.connection_id)
2220        .await?;
2221
2222    broadcast(
2223        Some(session.connection_id),
2224        guest_connection_ids.iter().copied(),
2225        |connection_id| {
2226            session
2227                .peer
2228                .forward_send(session.connection_id, connection_id, message.clone())
2229        },
2230    );
2231
2232    Ok(())
2233}
2234
2235/// Updates other participants with changes to the worktree settings
2236async fn update_worktree_settings(
2237    message: proto::UpdateWorktreeSettings,
2238    session: MessageContext,
2239) -> Result<()> {
2240    let guest_connection_ids = session
2241        .db()
2242        .await
2243        .update_worktree_settings(&message, session.connection_id)
2244        .await?;
2245
2246    broadcast(
2247        Some(session.connection_id),
2248        guest_connection_ids.iter().copied(),
2249        |connection_id| {
2250            session
2251                .peer
2252                .forward_send(session.connection_id, connection_id, message.clone())
2253        },
2254    );
2255
2256    Ok(())
2257}
2258
2259/// Notify other participants that a language server has started.
2260async fn start_language_server(
2261    request: proto::StartLanguageServer,
2262    session: MessageContext,
2263) -> Result<()> {
2264    let guest_connection_ids = session
2265        .db()
2266        .await
2267        .start_language_server(&request, session.connection_id)
2268        .await?;
2269
2270    broadcast(
2271        Some(session.connection_id),
2272        guest_connection_ids.iter().copied(),
2273        |connection_id| {
2274            session
2275                .peer
2276                .forward_send(session.connection_id, connection_id, request.clone())
2277        },
2278    );
2279    Ok(())
2280}
2281
2282/// Notify other participants that a language server has changed.
2283async fn update_language_server(
2284    request: proto::UpdateLanguageServer,
2285    session: MessageContext,
2286) -> Result<()> {
2287    let project_id = ProjectId::from_proto(request.project_id);
2288    let db = session.db().await;
2289
2290    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2291        && let Some(capabilities) = update.capabilities.clone()
2292    {
2293        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2294            .await?;
2295    }
2296
2297    let project_connection_ids = db
2298        .project_connection_ids(project_id, session.connection_id, true)
2299        .await?;
2300    broadcast(
2301        Some(session.connection_id),
2302        project_connection_ids.iter().copied(),
2303        |connection_id| {
2304            session
2305                .peer
2306                .forward_send(session.connection_id, connection_id, request.clone())
2307        },
2308    );
2309    Ok(())
2310}
2311
2312/// forward a project request to the host. These requests should be read only
2313/// as guests are allowed to send them.
2314async fn forward_read_only_project_request<T>(
2315    request: T,
2316    response: Response<T>,
2317    session: MessageContext,
2318) -> Result<()>
2319where
2320    T: EntityMessage + RequestMessage,
2321{
2322    let project_id = ProjectId::from_proto(request.remote_entity_id());
2323    let host_connection_id = session
2324        .db()
2325        .await
2326        .host_for_read_only_project_request(project_id, session.connection_id)
2327        .await?;
2328    let payload = session.forward_request(host_connection_id, request).await?;
2329    response.send(payload)?;
2330    Ok(())
2331}
2332
2333/// forward a project request to the host. These requests are disallowed
2334/// for guests.
2335async fn forward_mutating_project_request<T>(
2336    request: T,
2337    response: Response<T>,
2338    session: MessageContext,
2339) -> Result<()>
2340where
2341    T: EntityMessage + RequestMessage,
2342{
2343    let project_id = ProjectId::from_proto(request.remote_entity_id());
2344
2345    let host_connection_id = session
2346        .db()
2347        .await
2348        .host_for_mutating_project_request(project_id, session.connection_id)
2349        .await?;
2350    let payload = session.forward_request(host_connection_id, request).await?;
2351    response.send(payload)?;
2352    Ok(())
2353}
2354
2355async fn lsp_query(
2356    request: proto::LspQuery,
2357    response: Response<proto::LspQuery>,
2358    session: MessageContext,
2359) -> Result<()> {
2360    let (name, should_write) = request.query_name_and_write_permissions();
2361    tracing::Span::current().record("lsp_query_request", name);
2362    tracing::info!("lsp_query message received");
2363    if should_write {
2364        forward_mutating_project_request(request, response, session).await
2365    } else {
2366        forward_read_only_project_request(request, response, session).await
2367    }
2368}
2369
2370/// Notify other participants that a new buffer has been created
2371async fn create_buffer_for_peer(
2372    request: proto::CreateBufferForPeer,
2373    session: MessageContext,
2374) -> Result<()> {
2375    session
2376        .db()
2377        .await
2378        .check_user_is_project_host(
2379            ProjectId::from_proto(request.project_id),
2380            session.connection_id,
2381        )
2382        .await?;
2383    let peer_id = request.peer_id.context("invalid peer id")?;
2384    session
2385        .peer
2386        .forward_send(session.connection_id, peer_id.into(), request)?;
2387    Ok(())
2388}
2389
2390/// Notify other participants that a buffer has been updated. This is
2391/// allowed for guests as long as the update is limited to selections.
2392async fn update_buffer(
2393    request: proto::UpdateBuffer,
2394    response: Response<proto::UpdateBuffer>,
2395    session: MessageContext,
2396) -> Result<()> {
2397    let project_id = ProjectId::from_proto(request.project_id);
2398    let mut capability = Capability::ReadOnly;
2399
2400    for op in request.operations.iter() {
2401        match op.variant {
2402            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2403            Some(_) => capability = Capability::ReadWrite,
2404        }
2405    }
2406
2407    let host = {
2408        let guard = session
2409            .db()
2410            .await
2411            .connections_for_buffer_update(project_id, session.connection_id, capability)
2412            .await?;
2413
2414        let (host, guests) = &*guard;
2415
2416        broadcast(
2417            Some(session.connection_id),
2418            guests.clone(),
2419            |connection_id| {
2420                session
2421                    .peer
2422                    .forward_send(session.connection_id, connection_id, request.clone())
2423            },
2424        );
2425
2426        *host
2427    };
2428
2429    if host != session.connection_id {
2430        session.forward_request(host, request.clone()).await?;
2431    }
2432
2433    response.send(proto::Ack {})?;
2434    Ok(())
2435}
2436
2437async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2438    let project_id = ProjectId::from_proto(message.project_id);
2439
2440    let operation = message.operation.as_ref().context("invalid operation")?;
2441    let capability = match operation.variant.as_ref() {
2442        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2443            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2444                match buffer_op.variant {
2445                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2446                        Capability::ReadOnly
2447                    }
2448                    _ => Capability::ReadWrite,
2449                }
2450            } else {
2451                Capability::ReadWrite
2452            }
2453        }
2454        Some(_) => Capability::ReadWrite,
2455        None => Capability::ReadOnly,
2456    };
2457
2458    let guard = session
2459        .db()
2460        .await
2461        .connections_for_buffer_update(project_id, session.connection_id, capability)
2462        .await?;
2463
2464    let (host, guests) = &*guard;
2465
2466    broadcast(
2467        Some(session.connection_id),
2468        guests.iter().chain([host]).copied(),
2469        |connection_id| {
2470            session
2471                .peer
2472                .forward_send(session.connection_id, connection_id, message.clone())
2473        },
2474    );
2475
2476    Ok(())
2477}
2478
2479/// Notify other participants that a project has been updated.
2480async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2481    request: T,
2482    session: MessageContext,
2483) -> Result<()> {
2484    let project_id = ProjectId::from_proto(request.remote_entity_id());
2485    let project_connection_ids = session
2486        .db()
2487        .await
2488        .project_connection_ids(project_id, session.connection_id, false)
2489        .await?;
2490
2491    broadcast(
2492        Some(session.connection_id),
2493        project_connection_ids.iter().copied(),
2494        |connection_id| {
2495            session
2496                .peer
2497                .forward_send(session.connection_id, connection_id, request.clone())
2498        },
2499    );
2500    Ok(())
2501}
2502
2503/// Start following another user in a call.
2504async fn follow(
2505    request: proto::Follow,
2506    response: Response<proto::Follow>,
2507    session: MessageContext,
2508) -> Result<()> {
2509    let room_id = RoomId::from_proto(request.room_id);
2510    let project_id = request.project_id.map(ProjectId::from_proto);
2511    let leader_id = request.leader_id.context("invalid leader id")?.into();
2512    let follower_id = session.connection_id;
2513
2514    session
2515        .db()
2516        .await
2517        .check_room_participants(room_id, leader_id, session.connection_id)
2518        .await?;
2519
2520    let response_payload = session.forward_request(leader_id, request).await?;
2521    response.send(response_payload)?;
2522
2523    if let Some(project_id) = project_id {
2524        let room = session
2525            .db()
2526            .await
2527            .follow(room_id, project_id, leader_id, follower_id)
2528            .await?;
2529        room_updated(&room, &session.peer);
2530    }
2531
2532    Ok(())
2533}
2534
2535/// Stop following another user in a call.
2536async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2537    let room_id = RoomId::from_proto(request.room_id);
2538    let project_id = request.project_id.map(ProjectId::from_proto);
2539    let leader_id = request.leader_id.context("invalid leader id")?.into();
2540    let follower_id = session.connection_id;
2541
2542    session
2543        .db()
2544        .await
2545        .check_room_participants(room_id, leader_id, session.connection_id)
2546        .await?;
2547
2548    session
2549        .peer
2550        .forward_send(session.connection_id, leader_id, request)?;
2551
2552    if let Some(project_id) = project_id {
2553        let room = session
2554            .db()
2555            .await
2556            .unfollow(room_id, project_id, leader_id, follower_id)
2557            .await?;
2558        room_updated(&room, &session.peer);
2559    }
2560
2561    Ok(())
2562}
2563
2564/// Notify everyone following you of your current location.
2565async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2566    let room_id = RoomId::from_proto(request.room_id);
2567    let database = session.db.lock().await;
2568
2569    let connection_ids = if let Some(project_id) = request.project_id {
2570        let project_id = ProjectId::from_proto(project_id);
2571        database
2572            .project_connection_ids(project_id, session.connection_id, true)
2573            .await?
2574    } else {
2575        database
2576            .room_connection_ids(room_id, session.connection_id)
2577            .await?
2578    };
2579
2580    // For now, don't send view update messages back to that view's current leader.
2581    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2582        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2583        _ => None,
2584    });
2585
2586    for connection_id in connection_ids.iter().cloned() {
2587        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2588            session
2589                .peer
2590                .forward_send(session.connection_id, connection_id, request.clone())?;
2591        }
2592    }
2593    Ok(())
2594}
2595
2596/// Get public data about users.
2597async fn get_users(
2598    request: proto::GetUsers,
2599    response: Response<proto::GetUsers>,
2600    session: MessageContext,
2601) -> Result<()> {
2602    let user_ids = request
2603        .user_ids
2604        .into_iter()
2605        .map(UserId::from_proto)
2606        .collect();
2607    let users = session
2608        .db()
2609        .await
2610        .get_users_by_ids(user_ids)
2611        .await?
2612        .into_iter()
2613        .map(|user| proto::User {
2614            id: user.id.to_proto(),
2615            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2616            github_login: user.github_login,
2617            name: user.name,
2618        })
2619        .collect();
2620    response.send(proto::UsersResponse { users })?;
2621    Ok(())
2622}
2623
2624/// Search for users (to invite) buy Github login
2625async fn fuzzy_search_users(
2626    request: proto::FuzzySearchUsers,
2627    response: Response<proto::FuzzySearchUsers>,
2628    session: MessageContext,
2629) -> Result<()> {
2630    let query = request.query;
2631    let users = match query.len() {
2632        0 => vec![],
2633        1 | 2 => session
2634            .db()
2635            .await
2636            .get_user_by_github_login(&query)
2637            .await?
2638            .into_iter()
2639            .collect(),
2640        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2641    };
2642    let users = users
2643        .into_iter()
2644        .filter(|user| user.id != session.user_id())
2645        .map(|user| proto::User {
2646            id: user.id.to_proto(),
2647            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2648            github_login: user.github_login,
2649            name: user.name,
2650        })
2651        .collect();
2652    response.send(proto::UsersResponse { users })?;
2653    Ok(())
2654}
2655
2656/// Send a contact request to another user.
2657async fn request_contact(
2658    request: proto::RequestContact,
2659    response: Response<proto::RequestContact>,
2660    session: MessageContext,
2661) -> Result<()> {
2662    let requester_id = session.user_id();
2663    let responder_id = UserId::from_proto(request.responder_id);
2664    if requester_id == responder_id {
2665        return Err(anyhow!("cannot add yourself as a contact"))?;
2666    }
2667
2668    let notifications = session
2669        .db()
2670        .await
2671        .send_contact_request(requester_id, responder_id)
2672        .await?;
2673
2674    // Update outgoing contact requests of requester
2675    let mut update = proto::UpdateContacts::default();
2676    update.outgoing_requests.push(responder_id.to_proto());
2677    for connection_id in session
2678        .connection_pool()
2679        .await
2680        .user_connection_ids(requester_id)
2681    {
2682        session.peer.send(connection_id, update.clone())?;
2683    }
2684
2685    // Update incoming contact requests of responder
2686    let mut update = proto::UpdateContacts::default();
2687    update
2688        .incoming_requests
2689        .push(proto::IncomingContactRequest {
2690            requester_id: requester_id.to_proto(),
2691        });
2692    let connection_pool = session.connection_pool().await;
2693    for connection_id in connection_pool.user_connection_ids(responder_id) {
2694        session.peer.send(connection_id, update.clone())?;
2695    }
2696
2697    send_notifications(&connection_pool, &session.peer, notifications);
2698
2699    response.send(proto::Ack {})?;
2700    Ok(())
2701}
2702
2703/// Accept or decline a contact request
2704async fn respond_to_contact_request(
2705    request: proto::RespondToContactRequest,
2706    response: Response<proto::RespondToContactRequest>,
2707    session: MessageContext,
2708) -> Result<()> {
2709    let responder_id = session.user_id();
2710    let requester_id = UserId::from_proto(request.requester_id);
2711    let db = session.db().await;
2712    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2713        db.dismiss_contact_notification(responder_id, requester_id)
2714            .await?;
2715    } else {
2716        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2717
2718        let notifications = db
2719            .respond_to_contact_request(responder_id, requester_id, accept)
2720            .await?;
2721        let requester_busy = db.is_user_busy(requester_id).await?;
2722        let responder_busy = db.is_user_busy(responder_id).await?;
2723
2724        let pool = session.connection_pool().await;
2725        // Update responder with new contact
2726        let mut update = proto::UpdateContacts::default();
2727        if accept {
2728            update
2729                .contacts
2730                .push(contact_for_user(requester_id, requester_busy, &pool));
2731        }
2732        update
2733            .remove_incoming_requests
2734            .push(requester_id.to_proto());
2735        for connection_id in pool.user_connection_ids(responder_id) {
2736            session.peer.send(connection_id, update.clone())?;
2737        }
2738
2739        // Update requester with new contact
2740        let mut update = proto::UpdateContacts::default();
2741        if accept {
2742            update
2743                .contacts
2744                .push(contact_for_user(responder_id, responder_busy, &pool));
2745        }
2746        update
2747            .remove_outgoing_requests
2748            .push(responder_id.to_proto());
2749
2750        for connection_id in pool.user_connection_ids(requester_id) {
2751            session.peer.send(connection_id, update.clone())?;
2752        }
2753
2754        send_notifications(&pool, &session.peer, notifications);
2755    }
2756
2757    response.send(proto::Ack {})?;
2758    Ok(())
2759}
2760
2761/// Remove a contact.
2762async fn remove_contact(
2763    request: proto::RemoveContact,
2764    response: Response<proto::RemoveContact>,
2765    session: MessageContext,
2766) -> Result<()> {
2767    let requester_id = session.user_id();
2768    let responder_id = UserId::from_proto(request.user_id);
2769    let db = session.db().await;
2770    let (contact_accepted, deleted_notification_id) =
2771        db.remove_contact(requester_id, responder_id).await?;
2772
2773    let pool = session.connection_pool().await;
2774    // Update outgoing contact requests of requester
2775    let mut update = proto::UpdateContacts::default();
2776    if contact_accepted {
2777        update.remove_contacts.push(responder_id.to_proto());
2778    } else {
2779        update
2780            .remove_outgoing_requests
2781            .push(responder_id.to_proto());
2782    }
2783    for connection_id in pool.user_connection_ids(requester_id) {
2784        session.peer.send(connection_id, update.clone())?;
2785    }
2786
2787    // Update incoming contact requests of responder
2788    let mut update = proto::UpdateContacts::default();
2789    if contact_accepted {
2790        update.remove_contacts.push(requester_id.to_proto());
2791    } else {
2792        update
2793            .remove_incoming_requests
2794            .push(requester_id.to_proto());
2795    }
2796    for connection_id in pool.user_connection_ids(responder_id) {
2797        session.peer.send(connection_id, update.clone())?;
2798        if let Some(notification_id) = deleted_notification_id {
2799            session.peer.send(
2800                connection_id,
2801                proto::DeleteNotification {
2802                    notification_id: notification_id.to_proto(),
2803                },
2804            )?;
2805        }
2806    }
2807
2808    response.send(proto::Ack {})?;
2809    Ok(())
2810}
2811
2812const fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2813    version.0.minor() < 139
2814}
2815
2816async fn subscribe_to_channels(
2817    _: proto::SubscribeToChannels,
2818    session: MessageContext,
2819) -> Result<()> {
2820    subscribe_user_to_channels(session.user_id(), &session).await?;
2821    Ok(())
2822}
2823
2824async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2825    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2826    let mut pool = session.connection_pool().await;
2827    for membership in &channels_for_user.channel_memberships {
2828        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2829    }
2830    session.peer.send(
2831        session.connection_id,
2832        build_update_user_channels(&channels_for_user),
2833    )?;
2834    session.peer.send(
2835        session.connection_id,
2836        build_channels_update(channels_for_user),
2837    )?;
2838    Ok(())
2839}
2840
2841/// Creates a new channel.
2842async fn create_channel(
2843    request: proto::CreateChannel,
2844    response: Response<proto::CreateChannel>,
2845    session: MessageContext,
2846) -> Result<()> {
2847    let db = session.db().await;
2848
2849    let parent_id = request.parent_id.map(ChannelId::from_proto);
2850    let (channel, membership) = db
2851        .create_channel(&request.name, parent_id, session.user_id())
2852        .await?;
2853
2854    let root_id = channel.root_id();
2855    let channel = Channel::from_model(channel);
2856
2857    response.send(proto::CreateChannelResponse {
2858        channel: Some(channel.to_proto()),
2859        parent_id: request.parent_id,
2860    })?;
2861
2862    let mut connection_pool = session.connection_pool().await;
2863    if let Some(membership) = membership {
2864        connection_pool.subscribe_to_channel(
2865            membership.user_id,
2866            membership.channel_id,
2867            membership.role,
2868        );
2869        let update = proto::UpdateUserChannels {
2870            channel_memberships: vec![proto::ChannelMembership {
2871                channel_id: membership.channel_id.to_proto(),
2872                role: membership.role.into(),
2873            }],
2874            ..Default::default()
2875        };
2876        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2877            session.peer.send(connection_id, update.clone())?;
2878        }
2879    }
2880
2881    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2882        if !role.can_see_channel(channel.visibility) {
2883            continue;
2884        }
2885
2886        let update = proto::UpdateChannels {
2887            channels: vec![channel.to_proto()],
2888            ..Default::default()
2889        };
2890        session.peer.send(connection_id, update.clone())?;
2891    }
2892
2893    Ok(())
2894}
2895
2896/// Delete a channel
2897async fn delete_channel(
2898    request: proto::DeleteChannel,
2899    response: Response<proto::DeleteChannel>,
2900    session: MessageContext,
2901) -> Result<()> {
2902    let db = session.db().await;
2903
2904    let channel_id = request.channel_id;
2905    let (root_channel, removed_channels) = db
2906        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2907        .await?;
2908    response.send(proto::Ack {})?;
2909
2910    // Notify members of removed channels
2911    let mut update = proto::UpdateChannels::default();
2912    update
2913        .delete_channels
2914        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2915
2916    let connection_pool = session.connection_pool().await;
2917    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2918        session.peer.send(connection_id, update.clone())?;
2919    }
2920
2921    Ok(())
2922}
2923
2924/// Invite someone to join a channel.
2925async fn invite_channel_member(
2926    request: proto::InviteChannelMember,
2927    response: Response<proto::InviteChannelMember>,
2928    session: MessageContext,
2929) -> Result<()> {
2930    let db = session.db().await;
2931    let channel_id = ChannelId::from_proto(request.channel_id);
2932    let invitee_id = UserId::from_proto(request.user_id);
2933    let InviteMemberResult {
2934        channel,
2935        notifications,
2936    } = db
2937        .invite_channel_member(
2938            channel_id,
2939            invitee_id,
2940            session.user_id(),
2941            request.role().into(),
2942        )
2943        .await?;
2944
2945    let update = proto::UpdateChannels {
2946        channel_invitations: vec![channel.to_proto()],
2947        ..Default::default()
2948    };
2949
2950    let connection_pool = session.connection_pool().await;
2951    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2952        session.peer.send(connection_id, update.clone())?;
2953    }
2954
2955    send_notifications(&connection_pool, &session.peer, notifications);
2956
2957    response.send(proto::Ack {})?;
2958    Ok(())
2959}
2960
2961/// remove someone from a channel
2962async fn remove_channel_member(
2963    request: proto::RemoveChannelMember,
2964    response: Response<proto::RemoveChannelMember>,
2965    session: MessageContext,
2966) -> Result<()> {
2967    let db = session.db().await;
2968    let channel_id = ChannelId::from_proto(request.channel_id);
2969    let member_id = UserId::from_proto(request.user_id);
2970
2971    let RemoveChannelMemberResult {
2972        membership_update,
2973        notification_id,
2974    } = db
2975        .remove_channel_member(channel_id, member_id, session.user_id())
2976        .await?;
2977
2978    let mut connection_pool = session.connection_pool().await;
2979    notify_membership_updated(
2980        &mut connection_pool,
2981        membership_update,
2982        member_id,
2983        &session.peer,
2984    );
2985    for connection_id in connection_pool.user_connection_ids(member_id) {
2986        if let Some(notification_id) = notification_id {
2987            session
2988                .peer
2989                .send(
2990                    connection_id,
2991                    proto::DeleteNotification {
2992                        notification_id: notification_id.to_proto(),
2993                    },
2994                )
2995                .trace_err();
2996        }
2997    }
2998
2999    response.send(proto::Ack {})?;
3000    Ok(())
3001}
3002
3003/// Toggle the channel between public and private.
3004/// Care is taken to maintain the invariant that public channels only descend from public channels,
3005/// (though members-only channels can appear at any point in the hierarchy).
3006async fn set_channel_visibility(
3007    request: proto::SetChannelVisibility,
3008    response: Response<proto::SetChannelVisibility>,
3009    session: MessageContext,
3010) -> Result<()> {
3011    let db = session.db().await;
3012    let channel_id = ChannelId::from_proto(request.channel_id);
3013    let visibility = request.visibility().into();
3014
3015    let channel_model = db
3016        .set_channel_visibility(channel_id, visibility, session.user_id())
3017        .await?;
3018    let root_id = channel_model.root_id();
3019    let channel = Channel::from_model(channel_model);
3020
3021    let mut connection_pool = session.connection_pool().await;
3022    for (user_id, role) in connection_pool
3023        .channel_user_ids(root_id)
3024        .collect::<Vec<_>>()
3025        .into_iter()
3026    {
3027        let update = if role.can_see_channel(channel.visibility) {
3028            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3029            proto::UpdateChannels {
3030                channels: vec![channel.to_proto()],
3031                ..Default::default()
3032            }
3033        } else {
3034            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3035            proto::UpdateChannels {
3036                delete_channels: vec![channel.id.to_proto()],
3037                ..Default::default()
3038            }
3039        };
3040
3041        for connection_id in connection_pool.user_connection_ids(user_id) {
3042            session.peer.send(connection_id, update.clone())?;
3043        }
3044    }
3045
3046    response.send(proto::Ack {})?;
3047    Ok(())
3048}
3049
3050/// Alter the role for a user in the channel.
3051async fn set_channel_member_role(
3052    request: proto::SetChannelMemberRole,
3053    response: Response<proto::SetChannelMemberRole>,
3054    session: MessageContext,
3055) -> Result<()> {
3056    let db = session.db().await;
3057    let channel_id = ChannelId::from_proto(request.channel_id);
3058    let member_id = UserId::from_proto(request.user_id);
3059    let result = db
3060        .set_channel_member_role(
3061            channel_id,
3062            session.user_id(),
3063            member_id,
3064            request.role().into(),
3065        )
3066        .await?;
3067
3068    match result {
3069        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3070            let mut connection_pool = session.connection_pool().await;
3071            notify_membership_updated(
3072                &mut connection_pool,
3073                membership_update,
3074                member_id,
3075                &session.peer,
3076            )
3077        }
3078        db::SetMemberRoleResult::InviteUpdated(channel) => {
3079            let update = proto::UpdateChannels {
3080                channel_invitations: vec![channel.to_proto()],
3081                ..Default::default()
3082            };
3083
3084            for connection_id in session
3085                .connection_pool()
3086                .await
3087                .user_connection_ids(member_id)
3088            {
3089                session.peer.send(connection_id, update.clone())?;
3090            }
3091        }
3092    }
3093
3094    response.send(proto::Ack {})?;
3095    Ok(())
3096}
3097
3098/// Change the name of a channel
3099async fn rename_channel(
3100    request: proto::RenameChannel,
3101    response: Response<proto::RenameChannel>,
3102    session: MessageContext,
3103) -> Result<()> {
3104    let db = session.db().await;
3105    let channel_id = ChannelId::from_proto(request.channel_id);
3106    let channel_model = db
3107        .rename_channel(channel_id, session.user_id(), &request.name)
3108        .await?;
3109    let root_id = channel_model.root_id();
3110    let channel = Channel::from_model(channel_model);
3111
3112    response.send(proto::RenameChannelResponse {
3113        channel: Some(channel.to_proto()),
3114    })?;
3115
3116    let connection_pool = session.connection_pool().await;
3117    let update = proto::UpdateChannels {
3118        channels: vec![channel.to_proto()],
3119        ..Default::default()
3120    };
3121    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3122        if role.can_see_channel(channel.visibility) {
3123            session.peer.send(connection_id, update.clone())?;
3124        }
3125    }
3126
3127    Ok(())
3128}
3129
3130/// Move a channel to a new parent.
3131async fn move_channel(
3132    request: proto::MoveChannel,
3133    response: Response<proto::MoveChannel>,
3134    session: MessageContext,
3135) -> Result<()> {
3136    let channel_id = ChannelId::from_proto(request.channel_id);
3137    let to = ChannelId::from_proto(request.to);
3138
3139    let (root_id, channels) = session
3140        .db()
3141        .await
3142        .move_channel(channel_id, to, session.user_id())
3143        .await?;
3144
3145    let connection_pool = session.connection_pool().await;
3146    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3147        let channels = channels
3148            .iter()
3149            .filter_map(|channel| {
3150                if role.can_see_channel(channel.visibility) {
3151                    Some(channel.to_proto())
3152                } else {
3153                    None
3154                }
3155            })
3156            .collect::<Vec<_>>();
3157        if channels.is_empty() {
3158            continue;
3159        }
3160
3161        let update = proto::UpdateChannels {
3162            channels,
3163            ..Default::default()
3164        };
3165
3166        session.peer.send(connection_id, update.clone())?;
3167    }
3168
3169    response.send(Ack {})?;
3170    Ok(())
3171}
3172
3173async fn reorder_channel(
3174    request: proto::ReorderChannel,
3175    response: Response<proto::ReorderChannel>,
3176    session: MessageContext,
3177) -> Result<()> {
3178    let channel_id = ChannelId::from_proto(request.channel_id);
3179    let direction = request.direction();
3180
3181    let updated_channels = session
3182        .db()
3183        .await
3184        .reorder_channel(channel_id, direction, session.user_id())
3185        .await?;
3186
3187    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3188        let connection_pool = session.connection_pool().await;
3189        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3190            let channels = updated_channels
3191                .iter()
3192                .filter_map(|channel| {
3193                    if role.can_see_channel(channel.visibility) {
3194                        Some(channel.to_proto())
3195                    } else {
3196                        None
3197                    }
3198                })
3199                .collect::<Vec<_>>();
3200
3201            if channels.is_empty() {
3202                continue;
3203            }
3204
3205            let update = proto::UpdateChannels {
3206                channels,
3207                ..Default::default()
3208            };
3209
3210            session.peer.send(connection_id, update.clone())?;
3211        }
3212    }
3213
3214    response.send(Ack {})?;
3215    Ok(())
3216}
3217
3218/// Get the list of channel members
3219async fn get_channel_members(
3220    request: proto::GetChannelMembers,
3221    response: Response<proto::GetChannelMembers>,
3222    session: MessageContext,
3223) -> Result<()> {
3224    let db = session.db().await;
3225    let channel_id = ChannelId::from_proto(request.channel_id);
3226    let limit = if request.limit == 0 {
3227        u16::MAX as u64
3228    } else {
3229        request.limit
3230    };
3231    let (members, users) = db
3232        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3233        .await?;
3234    response.send(proto::GetChannelMembersResponse { members, users })?;
3235    Ok(())
3236}
3237
3238/// Accept or decline a channel invitation.
3239async fn respond_to_channel_invite(
3240    request: proto::RespondToChannelInvite,
3241    response: Response<proto::RespondToChannelInvite>,
3242    session: MessageContext,
3243) -> Result<()> {
3244    let db = session.db().await;
3245    let channel_id = ChannelId::from_proto(request.channel_id);
3246    let RespondToChannelInvite {
3247        membership_update,
3248        notifications,
3249    } = db
3250        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3251        .await?;
3252
3253    let mut connection_pool = session.connection_pool().await;
3254    if let Some(membership_update) = membership_update {
3255        notify_membership_updated(
3256            &mut connection_pool,
3257            membership_update,
3258            session.user_id(),
3259            &session.peer,
3260        );
3261    } else {
3262        let update = proto::UpdateChannels {
3263            remove_channel_invitations: vec![channel_id.to_proto()],
3264            ..Default::default()
3265        };
3266
3267        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3268            session.peer.send(connection_id, update.clone())?;
3269        }
3270    };
3271
3272    send_notifications(&connection_pool, &session.peer, notifications);
3273
3274    response.send(proto::Ack {})?;
3275
3276    Ok(())
3277}
3278
3279/// Join the channels' room
3280async fn join_channel(
3281    request: proto::JoinChannel,
3282    response: Response<proto::JoinChannel>,
3283    session: MessageContext,
3284) -> Result<()> {
3285    let channel_id = ChannelId::from_proto(request.channel_id);
3286    join_channel_internal(channel_id, Box::new(response), session).await
3287}
3288
/// A response handle that can complete a channel join with a
/// `proto::JoinRoomResponse`.
///
/// It is implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`. This lets `join_channel_internal` reply to
/// either kind of request through the same code path.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the original request, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Replies to a `JoinChannel` request.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // The type-qualified path resolves to `Response`'s own `send` (the one
        // used elsewhere as `response.send(..)`), not to this trait method, so
        // this call does not recurse.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Replies to a `JoinRoom` request.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Type-qualified for the same reason as the `JoinChannel` impl: this
        // targets `Response`'s own `send` and does not recurse.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3302
/// Shared implementation for joining a channel's room, used with any response
/// handle implementing `JoinChannelInternalResponse`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first.
/// 2. Join the channel's room in the database.
/// 3. Mint LiveKit credentials when a LiveKit client is configured.
/// 4. Reply with a `JoinRoomResponse`.
/// 5. Push membership changes and the updated room to participants.
/// 6. Broadcast channel participants and refresh the user's contacts.
///
/// # Errors
/// Returns database errors, send errors, and an error if the joined room has
/// no associated channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    // The DB and connection-pool guards acquired inside this block are dropped
    // when it ends, before `channel_updated` / `update_user_contacts` below
    // acquire them again.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` calls `session.db()` itself, so release
            // our guard first and re-acquire it afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a guest token and `can_publish = false`; every other role
        // gets a room token and may publish. If token creation fails, the error
        // is logged by `trace_err` and no connection info is sent at all.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // NOTE(review): if `joined_room.channel` is `None`, this error is returned
    // after the client already received a successful `JoinRoomResponse` above.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3396
3397/// Start editing the channel notes
3398async fn join_channel_buffer(
3399    request: proto::JoinChannelBuffer,
3400    response: Response<proto::JoinChannelBuffer>,
3401    session: MessageContext,
3402) -> Result<()> {
3403    let db = session.db().await;
3404    let channel_id = ChannelId::from_proto(request.channel_id);
3405
3406    let open_response = db
3407        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3408        .await?;
3409
3410    let collaborators = open_response.collaborators.clone();
3411    response.send(open_response)?;
3412
3413    let update = UpdateChannelBufferCollaborators {
3414        channel_id: channel_id.to_proto(),
3415        collaborators: collaborators.clone(),
3416    };
3417    channel_buffer_updated(
3418        session.connection_id,
3419        collaborators
3420            .iter()
3421            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3422        &update,
3423        &session.peer,
3424    );
3425
3426    Ok(())
3427}
3428
3429/// Edit the channel notes
3430async fn update_channel_buffer(
3431    request: proto::UpdateChannelBuffer,
3432    session: MessageContext,
3433) -> Result<()> {
3434    let db = session.db().await;
3435    let channel_id = ChannelId::from_proto(request.channel_id);
3436
3437    let (collaborators, epoch, version) = db
3438        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3439        .await?;
3440
3441    channel_buffer_updated(
3442        session.connection_id,
3443        collaborators.clone(),
3444        &proto::UpdateChannelBuffer {
3445            channel_id: channel_id.to_proto(),
3446            operations: request.operations,
3447        },
3448        &session.peer,
3449    );
3450
3451    let pool = &*session.connection_pool().await;
3452
3453    let non_collaborators =
3454        pool.channel_connection_ids(channel_id)
3455            .filter_map(|(connection_id, _)| {
3456                if collaborators.contains(&connection_id) {
3457                    None
3458                } else {
3459                    Some(connection_id)
3460                }
3461            });
3462
3463    broadcast(None, non_collaborators, |peer_id| {
3464        session.peer.send(
3465            peer_id,
3466            proto::UpdateChannels {
3467                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3468                    channel_id: channel_id.to_proto(),
3469                    epoch: epoch as u64,
3470                    version: version.clone(),
3471                }],
3472                ..Default::default()
3473            },
3474        )
3475    });
3476
3477    Ok(())
3478}
3479
3480/// Rejoin the channel notes after a connection blip
3481async fn rejoin_channel_buffers(
3482    request: proto::RejoinChannelBuffers,
3483    response: Response<proto::RejoinChannelBuffers>,
3484    session: MessageContext,
3485) -> Result<()> {
3486    let db = session.db().await;
3487    let buffers = db
3488        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3489        .await?;
3490
3491    for rejoined_buffer in &buffers {
3492        let collaborators_to_notify = rejoined_buffer
3493            .buffer
3494            .collaborators
3495            .iter()
3496            .filter_map(|c| Some(c.peer_id?.into()));
3497        channel_buffer_updated(
3498            session.connection_id,
3499            collaborators_to_notify,
3500            &proto::UpdateChannelBufferCollaborators {
3501                channel_id: rejoined_buffer.buffer.channel_id,
3502                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3503            },
3504            &session.peer,
3505        );
3506    }
3507
3508    response.send(proto::RejoinChannelBuffersResponse {
3509        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3510    })?;
3511
3512    Ok(())
3513}
3514
3515/// Stop editing the channel notes
3516async fn leave_channel_buffer(
3517    request: proto::LeaveChannelBuffer,
3518    response: Response<proto::LeaveChannelBuffer>,
3519    session: MessageContext,
3520) -> Result<()> {
3521    let db = session.db().await;
3522    let channel_id = ChannelId::from_proto(request.channel_id);
3523
3524    let left_buffer = db
3525        .leave_channel_buffer(channel_id, session.connection_id)
3526        .await?;
3527
3528    response.send(Ack {})?;
3529
3530    channel_buffer_updated(
3531        session.connection_id,
3532        left_buffer.connections,
3533        &proto::UpdateChannelBufferCollaborators {
3534            channel_id: channel_id.to_proto(),
3535            collaborators: left_buffer.collaborators,
3536        },
3537        &session.peer,
3538    );
3539
3540    Ok(())
3541}
3542
3543fn channel_buffer_updated<T: EnvelopedMessage>(
3544    sender_id: ConnectionId,
3545    collaborators: impl IntoIterator<Item = ConnectionId>,
3546    message: &T,
3547    peer: &Peer,
3548) {
3549    broadcast(Some(sender_id), collaborators, |peer_id| {
3550        peer.send(peer_id, message.clone())
3551    });
3552}
3553
3554fn send_notifications(
3555    connection_pool: &ConnectionPool,
3556    peer: &Peer,
3557    notifications: db::NotificationBatch,
3558) {
3559    for (user_id, notification) in notifications {
3560        for connection_id in connection_pool.user_connection_ids(user_id) {
3561            if let Err(error) = peer.send(
3562                connection_id,
3563                proto::AddNotification {
3564                    notification: Some(notification.clone()),
3565                },
3566            ) {
3567                tracing::error!(
3568                    "failed to send notification to {:?} {}",
3569                    connection_id,
3570                    error
3571                );
3572            }
3573        }
3574    }
3575}
3576
3577/// Send a message to the channel
3578async fn send_channel_message(
3579    _request: proto::SendChannelMessage,
3580    _response: Response<proto::SendChannelMessage>,
3581    _session: MessageContext,
3582) -> Result<()> {
3583    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3584}
3585
3586/// Delete a channel message
3587async fn remove_channel_message(
3588    _request: proto::RemoveChannelMessage,
3589    _response: Response<proto::RemoveChannelMessage>,
3590    _session: MessageContext,
3591) -> Result<()> {
3592    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3593}
3594
3595async fn update_channel_message(
3596    _request: proto::UpdateChannelMessage,
3597    _response: Response<proto::UpdateChannelMessage>,
3598    _session: MessageContext,
3599) -> Result<()> {
3600    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3601}
3602
3603/// Mark a channel message as read
3604async fn acknowledge_channel_message(
3605    _request: proto::AckChannelMessage,
3606    _session: MessageContext,
3607) -> Result<()> {
3608    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3609}
3610
3611/// Mark a buffer version as synced
3612async fn acknowledge_buffer_version(
3613    request: proto::AckBufferOperation,
3614    session: MessageContext,
3615) -> Result<()> {
3616    let buffer_id = BufferId::from_proto(request.buffer_id);
3617    session
3618        .db()
3619        .await
3620        .observe_buffer_version(
3621            buffer_id,
3622            session.user_id(),
3623            request.epoch as i32,
3624            &request.version,
3625        )
3626        .await?;
3627    Ok(())
3628}
3629
3630/// Get a Supermaven API key for the user
3631async fn get_supermaven_api_key(
3632    _request: proto::GetSupermavenApiKey,
3633    response: Response<proto::GetSupermavenApiKey>,
3634    session: MessageContext,
3635) -> Result<()> {
3636    let user_id: String = session.user_id().to_string();
3637    if !session.is_staff() {
3638        return Err(anyhow!("supermaven not enabled for this account"))?;
3639    }
3640
3641    let email = session.email().context("user must have an email")?;
3642
3643    let supermaven_admin_api = session
3644        .supermaven_client
3645        .as_ref()
3646        .context("supermaven not configured")?;
3647
3648    let result = supermaven_admin_api
3649        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3650        .await?;
3651
3652    response.send(proto::GetSupermavenApiKeyResponse {
3653        api_key: result.api_key,
3654    })?;
3655
3656    Ok(())
3657}
3658
3659/// Start receiving chat updates for a channel
3660async fn join_channel_chat(
3661    _request: proto::JoinChannelChat,
3662    _response: Response<proto::JoinChannelChat>,
3663    _session: MessageContext,
3664) -> Result<()> {
3665    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3666}
3667
3668/// Stop receiving chat updates for a channel
3669async fn leave_channel_chat(
3670    _request: proto::LeaveChannelChat,
3671    _session: MessageContext,
3672) -> Result<()> {
3673    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3674}
3675
3676/// Retrieve the chat history for a channel
3677async fn get_channel_messages(
3678    _request: proto::GetChannelMessages,
3679    _response: Response<proto::GetChannelMessages>,
3680    _session: MessageContext,
3681) -> Result<()> {
3682    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3683}
3684
3685/// Retrieve specific chat messages
3686async fn get_channel_messages_by_id(
3687    _request: proto::GetChannelMessagesById,
3688    _response: Response<proto::GetChannelMessagesById>,
3689    _session: MessageContext,
3690) -> Result<()> {
3691    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3692}
3693
3694/// Retrieve the current users notifications
3695async fn get_notifications(
3696    request: proto::GetNotifications,
3697    response: Response<proto::GetNotifications>,
3698    session: MessageContext,
3699) -> Result<()> {
3700    let notifications = session
3701        .db()
3702        .await
3703        .get_notifications(
3704            session.user_id(),
3705            NOTIFICATION_COUNT_PER_PAGE,
3706            request.before_id.map(db::NotificationId::from_proto),
3707        )
3708        .await?;
3709    response.send(proto::GetNotificationsResponse {
3710        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3711        notifications,
3712    })?;
3713    Ok(())
3714}
3715
3716/// Mark notifications as read
3717async fn mark_notification_as_read(
3718    request: proto::MarkNotificationRead,
3719    response: Response<proto::MarkNotificationRead>,
3720    session: MessageContext,
3721) -> Result<()> {
3722    let database = &session.db().await;
3723    let notifications = database
3724        .mark_notification_as_read_by_id(
3725            session.user_id(),
3726            NotificationId::from_proto(request.notification_id),
3727        )
3728        .await?;
3729    send_notifications(
3730        &*session.connection_pool().await,
3731        &session.peer,
3732        notifications,
3733    );
3734    response.send(proto::Ack {})?;
3735    Ok(())
3736}
3737
3738fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3739    let message = match message {
3740        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3741        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3742        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3743        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3744        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3745            code: frame.code.into(),
3746            reason: frame.reason.as_str().to_owned().into(),
3747        })),
3748        // We should never receive a frame while reading the message, according
3749        // to the `tungstenite` maintainers:
3750        //
3751        // > It cannot occur when you read messages from the WebSocket, but it
3752        // > can be used when you want to send the raw frames (e.g. you want to
3753        // > send the frames to the WebSocket without composing the full message first).
3754        // >
3755        // > — https://github.com/snapview/tungstenite-rs/issues/268
3756        TungsteniteMessage::Frame(_) => {
3757            bail!("received an unexpected frame while reading the message")
3758        }
3759    };
3760
3761    Ok(message)
3762}
3763
3764fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3765    match message {
3766        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3767        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3768        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3769        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3770        AxumMessage::Close(frame) => {
3771            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3772                code: frame.code.into(),
3773                reason: frame.reason.as_ref().into(),
3774            }))
3775        }
3776    }
3777}
3778
3779fn notify_membership_updated(
3780    connection_pool: &mut ConnectionPool,
3781    result: MembershipUpdated,
3782    user_id: UserId,
3783    peer: &Peer,
3784) {
3785    for membership in &result.new_channels.channel_memberships {
3786        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3787    }
3788    for channel_id in &result.removed_channels {
3789        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3790    }
3791
3792    let user_channels_update = proto::UpdateUserChannels {
3793        channel_memberships: result
3794            .new_channels
3795            .channel_memberships
3796            .iter()
3797            .map(|cm| proto::ChannelMembership {
3798                channel_id: cm.channel_id.to_proto(),
3799                role: cm.role.into(),
3800            })
3801            .collect(),
3802        ..Default::default()
3803    };
3804
3805    let mut update = build_channels_update(result.new_channels);
3806    update.delete_channels = result
3807        .removed_channels
3808        .into_iter()
3809        .map(|id| id.to_proto())
3810        .collect();
3811    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3812
3813    for connection_id in connection_pool.user_connection_ids(user_id) {
3814        peer.send(connection_id, user_channels_update.clone())
3815            .trace_err();
3816        peer.send(connection_id, update.clone()).trace_err();
3817    }
3818}
3819
3820fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3821    proto::UpdateUserChannels {
3822        channel_memberships: channels
3823            .channel_memberships
3824            .iter()
3825            .map(|m| proto::ChannelMembership {
3826                channel_id: m.channel_id.to_proto(),
3827                role: m.role.into(),
3828            })
3829            .collect(),
3830        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3831    }
3832}
3833
3834fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3835    let mut update = proto::UpdateChannels::default();
3836
3837    for channel in channels.channels {
3838        update.channels.push(channel.to_proto());
3839    }
3840
3841    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3842
3843    for (channel_id, participants) in channels.channel_participants {
3844        update
3845            .channel_participants
3846            .push(proto::ChannelParticipants {
3847                channel_id: channel_id.to_proto(),
3848                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3849            });
3850    }
3851
3852    for channel in channels.invited_channels {
3853        update.channel_invitations.push(channel.to_proto());
3854    }
3855
3856    update
3857}
3858
3859fn build_initial_contacts_update(
3860    contacts: Vec<db::Contact>,
3861    pool: &ConnectionPool,
3862) -> proto::UpdateContacts {
3863    let mut update = proto::UpdateContacts::default();
3864
3865    for contact in contacts {
3866        match contact {
3867            db::Contact::Accepted { user_id, busy } => {
3868                update.contacts.push(contact_for_user(user_id, busy, pool));
3869            }
3870            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3871            db::Contact::Incoming { user_id } => {
3872                update
3873                    .incoming_requests
3874                    .push(proto::IncomingContactRequest {
3875                        requester_id: user_id.to_proto(),
3876                    })
3877            }
3878        }
3879    }
3880
3881    update
3882}
3883
3884fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3885    proto::Contact {
3886        user_id: user_id.to_proto(),
3887        online: pool.is_user_online(user_id),
3888        busy,
3889    }
3890}
3891
3892fn room_updated(room: &proto::Room, peer: &Peer) {
3893    broadcast(
3894        None,
3895        room.participants
3896            .iter()
3897            .filter_map(|participant| Some(participant.peer_id?.into())),
3898        |peer_id| {
3899            peer.send(
3900                peer_id,
3901                proto::RoomUpdated {
3902                    room: Some(room.clone()),
3903                },
3904            )
3905        },
3906    );
3907}
3908
3909fn channel_updated(
3910    channel: &db::channel::Model,
3911    room: &proto::Room,
3912    peer: &Peer,
3913    pool: &ConnectionPool,
3914) {
3915    let participants = room
3916        .participants
3917        .iter()
3918        .map(|p| p.user_id)
3919        .collect::<Vec<_>>();
3920
3921    broadcast(
3922        None,
3923        pool.channel_connection_ids(channel.root_id())
3924            .filter_map(|(channel_id, role)| {
3925                role.can_see_channel(channel.visibility)
3926                    .then_some(channel_id)
3927            }),
3928        |peer_id| {
3929            peer.send(
3930                peer_id,
3931                proto::UpdateChannels {
3932                    channel_participants: vec![proto::ChannelParticipants {
3933                        channel_id: channel.id.to_proto(),
3934                        participant_user_ids: participants.clone(),
3935                    }],
3936                    ..Default::default()
3937                },
3938            )
3939        },
3940    );
3941}
3942
3943async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3944    let db = session.db().await;
3945
3946    let contacts = db.get_contacts(user_id).await?;
3947    let busy = db.is_user_busy(user_id).await?;
3948
3949    let pool = session.connection_pool().await;
3950    let updated_contact = contact_for_user(user_id, busy, &pool);
3951    for contact in contacts {
3952        if let db::Contact::Accepted {
3953            user_id: contact_user_id,
3954            ..
3955        } = contact
3956        {
3957            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3958                session
3959                    .peer
3960                    .send(
3961                        contact_conn_id,
3962                        proto::UpdateContacts {
3963                            contacts: vec![updated_contact.clone()],
3964                            remove_contacts: Default::default(),
3965                            incoming_requests: Default::default(),
3966                            remove_incoming_requests: Default::default(),
3967                            outgoing_requests: Default::default(),
3968                            remove_outgoing_requests: Default::default(),
3969                        },
3970                    )
3971                    .trace_err();
3972            }
3973        }
3974    }
3975    Ok(())
3976}
3977
/// Removes `connection_id` from whatever room the database has it in, then
/// notifies everyone affected.
///
/// Returns `Ok(())` without doing anything if the connection was not in a room.
/// Otherwise, in order:
/// - notifies the collaborators of projects left (`project_left`);
/// - broadcasts the room;
/// - broadcasts channel participants when the room belongs to a channel;
/// - sends `CallCanceled` to users whose pending calls were canceled;
/// - refreshes the contacts of the session user and those users;
/// - removes the session user from the LiveKit room, deleting the room if the
///   database reports it was deleted.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    // A set, so each affected user's contacts are refreshed once.
    let mut contacts_to_update = HashSet::default();

    // Assigned only in the `if let` branch below; the `else` branch returns.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // The DB guard is a temporary of this `if let` scrutinee, so it is still
    // held while the body below runs.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Order matters: `livekit_room` is taken out of the room *before* the
        // room itself is taken, so the `room` broadcast below carries a
        // default (emptied) `livekit_room`.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the pool guard is released before `update_user_contacts`,
    // which acquires the pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // NOTE(review): an error here returns early and skips the LiveKit cleanup
    // below — confirm that is acceptable.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit failures are logged (`trace_err`) rather than propagated.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4051
4052async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4053    let left_channel_buffers = session
4054        .db()
4055        .await
4056        .leave_channel_buffers(session.connection_id)
4057        .await?;
4058
4059    for left_buffer in left_channel_buffers {
4060        channel_buffer_updated(
4061            session.connection_id,
4062            left_buffer.connections,
4063            &proto::UpdateChannelBufferCollaborators {
4064                channel_id: left_buffer.channel_id.to_proto(),
4065                collaborators: left_buffer.collaborators,
4066            },
4067            &session.peer,
4068        );
4069    }
4070
4071    Ok(())
4072}
4073
4074fn project_left(project: &db::LeftProject, session: &Session) {
4075    for connection_id in &project.connection_ids {
4076        if project.should_unshare {
4077            session
4078                .peer
4079                .send(
4080                    *connection_id,
4081                    proto::UnshareProject {
4082                        project_id: project.id.to_proto(),
4083                    },
4084                )
4085                .trace_err();
4086        } else {
4087            session
4088                .peer
4089                .send(
4090                    *connection_id,
4091                    proto::RemoveProjectCollaborator {
4092                        project_id: project.id.to_proto(),
4093                        peer_id: Some(session.connection_id.into()),
4094                    },
4095                )
4096                .trace_err();
4097        }
4098    }
4099}
4100
4101pub trait ResultExt {
4102    type Ok;
4103
4104    fn trace_err(self) -> Option<Self::Ok>;
4105}
4106
4107impl<T, E> ResultExt for Result<T, E>
4108where
4109    E: std::fmt::Debug,
4110{
4111    type Ok = T;
4112
4113    #[track_caller]
4114    fn trace_err(self) -> Option<T> {
4115        match self {
4116            Ok(value) => Some(value),
4117            Err(error) => {
4118                tracing::error!("{:?}", error);
4119                None
4120            }
4121        }
4122    }
4123}