rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::{
   5    AppState, Error, Result, auth,
   6    db::{
   7        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
   8        InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
   9        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
  10    },
  11    executor::Executor,
  12};
  13use anyhow::{Context as _, anyhow, bail};
  14use async_tungstenite::tungstenite::{
  15    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  16};
  17use axum::headers::UserAgent;
  18use axum::{
  19    Extension, Router, TypedHeader,
  20    body::Body,
  21    extract::{
  22        ConnectInfo, WebSocketUpgrade,
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30};
  31use collections::{HashMap, HashSet};
  32pub use connection_pool::{ConnectionPool, ZedVersion};
  33use core::fmt::{self, Debug, Formatter};
  34use futures::TryFutureExt as _;
  35use reqwest_client::ReqwestClient;
  36use rpc::proto::split_repository_update;
  37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  38use tracing::Span;
  39use util::paths::PathStyle;
  40
  41use futures::{
  42    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  43    stream::FuturesUnordered,
  44};
  45use prometheus::{IntGauge, register_int_gauge};
  46use rpc::{
  47    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  50        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52};
  53use semantic_version::SemanticVersion;
  54use serde::{Serialize, Serializer};
  55use std::{
  56    any::TypeId,
  57    future::Future,
  58    marker::PhantomData,
  59    mem,
  60    net::SocketAddr,
  61    ops::{Deref, DerefMut},
  62    rc::Rc,
  63    sync::{
  64        Arc, OnceLock,
  65        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  66    },
  67    time::{Duration, Instant},
  68};
  69use tokio::sync::{Semaphore, watch};
  70use tower::ServiceBuilder;
  71use tracing::{
  72    Instrument,
  73    field::{self},
  74    info_span, instrument,
  75};
  76
  77pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  78
  79// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  80pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  81
  82const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  83const MAX_CONCURRENT_CONNECTIONS: usize = 512;
  84
  85static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
  86
  87const TOTAL_DURATION_MS: &str = "total_duration_ms";
  88const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
  89const QUEUE_DURATION_MS: &str = "queue_duration_ms";
  90const HOST_WAITING_MS: &str = "host_waiting_ms";
  91
  92type MessageHandler =
  93    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
  94
  95pub struct ConnectionGuard;
  96
  97impl ConnectionGuard {
  98    pub fn try_acquire() -> Result<Self, ()> {
  99        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
 100        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 101            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 102            tracing::error!(
 103                "too many concurrent connections: {}",
 104                current_connections + 1
 105            );
 106            return Err(());
 107        }
 108        Ok(ConnectionGuard)
 109    }
 110}
 111
 112impl Drop for ConnectionGuard {
 113    fn drop(&mut self) {
 114        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 115    }
 116}
 117
 118struct Response<R> {
 119    peer: Arc<Peer>,
 120    receipt: Receipt<R>,
 121    responded: Arc<AtomicBool>,
 122}
 123
 124impl<R: RequestMessage> Response<R> {
 125    fn send(self, payload: R::Response) -> Result<()> {
 126        self.responded.store(true, SeqCst);
 127        self.peer.respond(self.receipt, payload)?;
 128        Ok(())
 129    }
 130}
 131
 132#[derive(Clone, Debug)]
 133pub enum Principal {
 134    User(User),
 135    Impersonated { user: User, admin: User },
 136}
 137
 138impl Principal {
 139    fn update_span(&self, span: &tracing::Span) {
 140        match &self {
 141            Principal::User(user) => {
 142                span.record("user_id", user.id.0);
 143                span.record("login", &user.github_login);
 144            }
 145            Principal::Impersonated { user, admin } => {
 146                span.record("user_id", user.id.0);
 147                span.record("login", &user.github_login);
 148                span.record("impersonator", &admin.github_login);
 149            }
 150        }
 151    }
 152}
 153
 154#[derive(Clone)]
 155struct MessageContext {
 156    session: Session,
 157    span: tracing::Span,
 158}
 159
 160impl Deref for MessageContext {
 161    type Target = Session;
 162
 163    fn deref(&self) -> &Self::Target {
 164        &self.session
 165    }
 166}
 167
 168impl MessageContext {
 169    pub fn forward_request<T: RequestMessage>(
 170        &self,
 171        receiver_id: ConnectionId,
 172        request: T,
 173    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 174        let request_start_time = Instant::now();
 175        let span = self.span.clone();
 176        tracing::info!("start forwarding request");
 177        self.peer
 178            .forward_request(self.connection_id, receiver_id, request)
 179            .inspect(move |_| {
 180                span.record(
 181                    HOST_WAITING_MS,
 182                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 183                );
 184            })
 185            .inspect_err(|_| tracing::error!("error forwarding request"))
 186            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 187    }
 188}
 189
/// State of one client connection, shared (by clone) with every handler invoked for it.
#[derive(Clone)]
struct Session {
    /// Who the connection is authenticated as (possibly an impersonation).
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle; normally accessed through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Connection pool; normally accessed through `Session::connection_pool`.
    /// Presumably the same pool as `Server::connection_pool` — confirm at construction.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    // `allow(unused)`: stored on the session but not necessarily read anywhere.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    // `allow(unused)`: stored on the session but not necessarily read anywhere.
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
 206
 207impl Session {
 208    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 209        #[cfg(test)]
 210        tokio::task::yield_now().await;
 211        let guard = self.db.lock().await;
 212        #[cfg(test)]
 213        tokio::task::yield_now().await;
 214        guard
 215    }
 216
 217    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 218        #[cfg(test)]
 219        tokio::task::yield_now().await;
 220        let guard = self.connection_pool.lock();
 221        ConnectionPoolGuard {
 222            guard,
 223            _not_send: PhantomData,
 224        }
 225    }
 226
 227    fn is_staff(&self) -> bool {
 228        match &self.principal {
 229            Principal::User(user) => user.admin,
 230            Principal::Impersonated { .. } => true,
 231        }
 232    }
 233
 234    fn user_id(&self) -> UserId {
 235        match &self.principal {
 236            Principal::User(user) => user.id,
 237            Principal::Impersonated { user, .. } => user.id,
 238        }
 239    }
 240
 241    pub fn email(&self) -> Option<String> {
 242        match &self.principal {
 243            Principal::User(user) => user.email_address.clone(),
 244            Principal::Impersonated { user, .. } => user.email_address.clone(),
 245        }
 246    }
 247}
 248
 249impl Debug for Session {
 250    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 251        let mut result = f.debug_struct("Session");
 252        match &self.principal {
 253            Principal::User(user) => {
 254                result.field("user", &user.github_login);
 255            }
 256            Principal::Impersonated { user, admin } => {
 257                result.field("user", &user.github_login);
 258                result.field("impersonator", &admin.github_login);
 259            }
 260        }
 261        result.field("connection_id", &self.connection_id).finish()
 262    }
 263}
 264
 265struct DbHandle(Arc<Database>);
 266
 267impl Deref for DbHandle {
 268    type Target = Database;
 269
 270    fn deref(&self) -> &Self::Target {
 271        self.0.as_ref()
 272    }
 273}
 274
/// The collaboration RPC server: dispatches incoming messages to the handlers registered in
/// `Server::new` and tracks live connections in a shared pool.
pub struct Server {
    /// Id of this server instance; mutable so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Message handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag: `teardown` sends `true`, the test-only `reset` sends `false`.
    teardown: watch::Sender<bool>,
}
 283
/// Lock guard for the shared `ConnectionPool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc<()>` is `!Send`, so this marker makes the whole guard `!Send`: a `Send` future
    // cannot keep the pool locked across an `.await`.
    _not_send: PhantomData<Rc<()>>,
}
 288
/// Serializable view of the server's peer and its (locked) connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the value the guard dereferences to. This relies on a `Deref` impl
    // for `ConnectionPoolGuard` that is not shown here — NOTE(review): confirm.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 295
 296pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 297where
 298    S: Serializer,
 299    T: Deref<Target = U>,
 300    U: Serialize,
 301{
 302    Serialize::serialize(value.deref(), serializer)
 303}
 304
 305impl Server {
 306    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 307        let mut server = Self {
 308            id: parking_lot::Mutex::new(id),
 309            peer: Peer::new(id.0 as u32),
 310            app_state,
 311            connection_pool: Default::default(),
 312            handlers: Default::default(),
 313            teardown: watch::channel(false).0,
 314        };
 315
 316        server
 317            .add_request_handler(ping)
 318            .add_request_handler(create_room)
 319            .add_request_handler(join_room)
 320            .add_request_handler(rejoin_room)
 321            .add_request_handler(leave_room)
 322            .add_request_handler(set_room_participant_role)
 323            .add_request_handler(call)
 324            .add_request_handler(cancel_call)
 325            .add_message_handler(decline_call)
 326            .add_request_handler(update_participant_location)
 327            .add_request_handler(share_project)
 328            .add_message_handler(unshare_project)
 329            .add_request_handler(join_project)
 330            .add_message_handler(leave_project)
 331            .add_request_handler(update_project)
 332            .add_request_handler(update_worktree)
 333            .add_request_handler(update_repository)
 334            .add_request_handler(remove_repository)
 335            .add_message_handler(start_language_server)
 336            .add_message_handler(update_language_server)
 337            .add_message_handler(update_diagnostic_summary)
 338            .add_message_handler(update_worktree_settings)
 339            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 340            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 341            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 342            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 343            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 344            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 345            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 346            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 347            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 348            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 349            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 350            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 351            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 352            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 353            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 354            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 355            .add_request_handler(
 356                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 357            )
 358            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 359            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 360            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 361            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 362            .add_request_handler(
 363                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 364            )
 365            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 366            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 367            .add_request_handler(
 368                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 369            )
 370            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 371            .add_request_handler(
 372                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 373            )
 374            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 375            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 376            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 377            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 378            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 379            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 380            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 381            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 382            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 383            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 384            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 385            .add_request_handler(
 386                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 387            )
 388            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 389            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 390            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 391            .add_request_handler(lsp_query)
 392            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
 393            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 394            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 395            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 396            .add_message_handler(create_buffer_for_peer)
 397            .add_request_handler(update_buffer)
 398            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 399            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 400            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 401            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 402            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 403            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 404            .add_message_handler(
 405                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 406            )
 407            .add_request_handler(get_users)
 408            .add_request_handler(fuzzy_search_users)
 409            .add_request_handler(request_contact)
 410            .add_request_handler(remove_contact)
 411            .add_request_handler(respond_to_contact_request)
 412            .add_message_handler(subscribe_to_channels)
 413            .add_request_handler(create_channel)
 414            .add_request_handler(delete_channel)
 415            .add_request_handler(invite_channel_member)
 416            .add_request_handler(remove_channel_member)
 417            .add_request_handler(set_channel_member_role)
 418            .add_request_handler(set_channel_visibility)
 419            .add_request_handler(rename_channel)
 420            .add_request_handler(join_channel_buffer)
 421            .add_request_handler(leave_channel_buffer)
 422            .add_message_handler(update_channel_buffer)
 423            .add_request_handler(rejoin_channel_buffers)
 424            .add_request_handler(get_channel_members)
 425            .add_request_handler(respond_to_channel_invite)
 426            .add_request_handler(join_channel)
 427            .add_request_handler(join_channel_chat)
 428            .add_message_handler(leave_channel_chat)
 429            .add_request_handler(send_channel_message)
 430            .add_request_handler(remove_channel_message)
 431            .add_request_handler(update_channel_message)
 432            .add_request_handler(get_channel_messages)
 433            .add_request_handler(get_channel_messages_by_id)
 434            .add_request_handler(get_notifications)
 435            .add_request_handler(mark_notification_as_read)
 436            .add_request_handler(move_channel)
 437            .add_request_handler(reorder_channel)
 438            .add_request_handler(follow)
 439            .add_message_handler(unfollow)
 440            .add_message_handler(update_followers)
 441            .add_message_handler(acknowledge_channel_message)
 442            .add_message_handler(acknowledge_buffer_version)
 443            .add_request_handler(get_supermaven_api_key)
 444            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 445            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 446            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 447            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 448            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 449            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 450            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 451            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
 452            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 453            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 454            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 455            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 456            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 457            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 458            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 459            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 460            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 461            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 462            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 463            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 464            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 465            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 466            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 467            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 468            .add_message_handler(update_context)
 469            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
 470            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>);
 471
 472        Arc::new(server)
 473    }
 474
    /// Schedules the one-off cleanup of resources left behind by earlier server instances.
    ///
    /// Spawns a detached task, instrumented with a `start server` span, that first waits
    /// `CLEANUP_TIMEOUT` and then runs these steps in order:
    ///
    /// 1. Deletes stale channel chat participants.
    /// 2. For each stale channel buffer, clears stale collaborators and sends the refreshed
    ///    collaborator list to the connections listed on the refreshed buffer.
    /// 3. For each stale room:
    ///    - clears stale participants;
    ///    - calls `room_updated` / `channel_updated`;
    ///    - sends `CallCanceled` to users whose calls were canceled;
    ///    - pushes updated contact entries to the accepted contacts of every affected user;
    ///    - deletes the LiveKit room when no participants remain and a LiveKit client is
    ///      configured.
    /// 4. Deletes stale channel chat participants again.
    /// 5. Clears old worktree entries.
    /// 6. Deletes stale servers.
    ///
    /// Every database/LiveKit/send error inside the task is logged via `trace_err` and
    /// otherwise ignored. This method returns as soon as the task is spawned and always
    /// returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created now, awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults if clearing the room fails: nobody to notify, no LiveKit
                        // room to delete.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool guard is released before the `.await`s below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Only notify when both lookups succeeded; the pool is locked
                            // after the awaits and released at the end of this block.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                // NOTE(review): same call as at the start of the task; presumably repeated to
                // catch participants that became stale in the meantime — confirm it is needed.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 645
 646    pub fn teardown(&self) {
 647        self.peer.teardown();
 648        self.connection_pool.lock().reset();
 649        let _ = self.teardown.send(true);
 650    }
 651
 652    #[cfg(test)]
 653    pub fn reset(&self, id: ServerId) {
 654        self.teardown();
 655        *self.id.lock() = id;
 656        self.peer.reset(id.0 as u32);
 657        let _ = self.teardown.send(false);
 658    }
 659
 660    #[cfg(test)]
 661    pub fn id(&self) -> ServerId {
 662        *self.id.lock()
 663    }
 664
    /// Registers `handler` as the handler for messages of type `M`.
    ///
    /// The stored closure downcasts the type-erased envelope to
    /// `TypedEnvelope<M>`, invokes `handler` with a `MessageContext`, and returns a
    /// boxed future. When that future completes, the total, processing, and queue
    /// durations (in milliseconds) are recorded on the message span. Handler
    /// errors are logged rather than propagated.
    ///
    /// # Panics
    ///
    /// Panics if a handler for `M` has already been registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // Handlers are keyed by `TypeId::of::<M>()` and looked up by the
                // payload's type id, so this downcast is expected to succeed.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                // Marks the moment the dispatcher handed the message to us.
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    // Time between receipt and dispatch (waiting in the queue).
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
 708
 709    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 710    where
 711        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 712        Fut: 'static + Send + Future<Output = Result<()>>,
 713        M: EnvelopedMessage,
 714    {
 715        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 716        self
 717    }
 718
 719    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 720    where
 721        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 722        Fut: Send + Future<Output = Result<()>>,
 723        M: RequestMessage,
 724    {
 725        let handler = Arc::new(handler);
 726        self.add_handler(move |envelope, session| {
 727            let receipt = envelope.receipt();
 728            let handler = handler.clone();
 729            async move {
 730                let peer = session.peer.clone();
 731                let responded = Arc::new(AtomicBool::default());
 732                let response = Response {
 733                    peer: peer.clone(),
 734                    responded: responded.clone(),
 735                    receipt,
 736                };
 737                match (handler)(envelope.payload, response, session).await {
 738                    Ok(()) => {
 739                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 740                            Ok(())
 741                        } else {
 742                            Err(anyhow!("handler did not send a response"))?
 743                        }
 744                    }
 745                    Err(error) => {
 746                        let proto_err = match &error {
 747                            Error::Internal(err) => err.to_proto(),
 748                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 749                        };
 750                        peer.respond_with_error(receipt, proto_err)?;
 751                        Err(error)
 752                    }
 753                }
 754            }
 755        })
 756    }
 757
 758    pub fn handle_connection(
 759        self: &Arc<Self>,
 760        connection: Connection,
 761        address: String,
 762        principal: Principal,
 763        zed_version: ZedVersion,
 764        release_channel: Option<String>,
 765        user_agent: Option<String>,
 766        geoip_country_code: Option<String>,
 767        system_id: Option<String>,
 768        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 769        executor: Executor,
 770        connection_guard: Option<ConnectionGuard>,
 771    ) -> impl Future<Output = ()> + use<> {
 772        let this = self.clone();
 773        let span = info_span!("handle connection", %address,
 774            connection_id=field::Empty,
 775            user_id=field::Empty,
 776            login=field::Empty,
 777            impersonator=field::Empty,
 778            user_agent=field::Empty,
 779            geoip_country_code=field::Empty,
 780            release_channel=field::Empty,
 781        );
 782        principal.update_span(&span);
 783        if let Some(user_agent) = user_agent {
 784            span.record("user_agent", user_agent);
 785        }
 786        if let Some(release_channel) = release_channel {
 787            span.record("release_channel", release_channel);
 788        }
 789
 790        if let Some(country_code) = geoip_country_code.as_ref() {
 791            span.record("geoip_country_code", country_code);
 792        }
 793
 794        let mut teardown = self.teardown.subscribe();
 795        async move {
 796            if *teardown.borrow() {
 797                tracing::error!("server is tearing down");
 798                return;
 799            }
 800
 801            let (connection_id, handle_io, mut incoming_rx) =
 802                this.peer.add_connection(connection, {
 803                    let executor = executor.clone();
 804                    move |duration| executor.sleep(duration)
 805                });
 806            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 807
 808            tracing::info!("connection opened");
 809
 810            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
 811            let http_client = match ReqwestClient::user_agent(&user_agent) {
 812                Ok(http_client) => Arc::new(http_client),
 813                Err(error) => {
 814                    tracing::error!(?error, "failed to create HTTP client");
 815                    return;
 816                }
 817            };
 818
 819            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
 820                |supermaven_admin_api_key| {
 821                    Arc::new(SupermavenAdminApi::new(
 822                        supermaven_admin_api_key.to_string(),
 823                        http_client.clone(),
 824                    ))
 825                },
 826            );
 827
 828            let session = Session {
 829                principal: principal.clone(),
 830                connection_id,
 831                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 832                peer: this.peer.clone(),
 833                connection_pool: this.connection_pool.clone(),
 834                app_state: this.app_state.clone(),
 835                geoip_country_code,
 836                system_id,
 837                _executor: executor.clone(),
 838                supermaven_client,
 839            };
 840
 841            if let Err(error) = this
 842                .send_initial_client_update(
 843                    connection_id,
 844                    zed_version,
 845                    send_connection_id,
 846                    &session,
 847                )
 848                .await
 849            {
 850                tracing::error!(?error, "failed to send initial client update");
 851                return;
 852            }
 853            drop(connection_guard);
 854
 855            let handle_io = handle_io.fuse();
 856            futures::pin_mut!(handle_io);
 857
 858            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 859            // This prevents deadlocks when e.g., client A performs a request to client B and
 860            // client B performs a request to client A. If both clients stop processing further
 861            // messages until their respective request completes, they won't have a chance to
 862            // respond to the other client's request and cause a deadlock.
 863            //
 864            // This arrangement ensures we will attempt to process earlier messages first, but fall
 865            // back to processing messages arrived later in the spirit of making progress.
 866            const MAX_CONCURRENT_HANDLERS: usize = 256;
 867            let mut foreground_message_handlers = FuturesUnordered::new();
 868            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
 869            let get_concurrent_handlers = {
 870                let concurrent_handlers = concurrent_handlers.clone();
 871                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
 872            };
 873            loop {
 874                let next_message = async {
 875                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 876                    let message = incoming_rx.next().await;
 877                    // Cache the concurrent_handlers here, so that we know what the
 878                    // queue looks like as each handler starts
 879                    (permit, message, get_concurrent_handlers())
 880                }
 881                .fuse();
 882                futures::pin_mut!(next_message);
 883                futures::select_biased! {
 884                    _ = teardown.changed().fuse() => return,
 885                    result = handle_io => {
 886                        if let Err(error) = result {
 887                            tracing::error!(?error, "error handling I/O");
 888                        }
 889                        break;
 890                    }
 891                    _ = foreground_message_handlers.next() => {}
 892                    next_message = next_message => {
 893                        let (permit, message, concurrent_handlers) = next_message;
 894                        if let Some(message) = message {
 895                            let type_name = message.payload_type_name();
 896                            // note: we copy all the fields from the parent span so we can query them in the logs.
 897                            // (https://github.com/tokio-rs/tracing/issues/2670).
 898                            let span = tracing::info_span!("receive message",
 899                                %connection_id,
 900                                %address,
 901                                type_name,
 902                                concurrent_handlers,
 903                                user_id=field::Empty,
 904                                login=field::Empty,
 905                                impersonator=field::Empty,
 906                                lsp_query_request=field::Empty,
 907                                release_channel=field::Empty,
 908                                { TOTAL_DURATION_MS }=field::Empty,
 909                                { PROCESSING_DURATION_MS }=field::Empty,
 910                                { QUEUE_DURATION_MS }=field::Empty,
 911                                { HOST_WAITING_MS }=field::Empty
 912                            );
 913                            principal.update_span(&span);
 914                            let span_enter = span.enter();
 915                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 916                                let is_background = message.is_background();
 917                                let handle_message = (handler)(message, session.clone(), span.clone());
 918                                drop(span_enter);
 919
 920                                let handle_message = async move {
 921                                    handle_message.await;
 922                                    drop(permit);
 923                                }.instrument(span);
 924                                if is_background {
 925                                    executor.spawn_detached(handle_message);
 926                                } else {
 927                                    foreground_message_handlers.push(handle_message);
 928                                }
 929                            } else {
 930                                tracing::error!("no message handler");
 931                            }
 932                        } else {
 933                            tracing::info!("connection closed");
 934                            break;
 935                        }
 936                    }
 937                }
 938            }
 939
 940            drop(foreground_message_handlers);
 941            let concurrent_handlers = get_concurrent_handlers();
 942            tracing::info!(concurrent_handlers, "signing out");
 943            if let Err(error) = connection_lost(session, teardown, executor).await {
 944                tracing::error!(?error, "error signing out");
 945            }
 946        }
 947        .instrument(span)
 948    }
 949
 950    async fn send_initial_client_update(
 951        &self,
 952        connection_id: ConnectionId,
 953        zed_version: ZedVersion,
 954        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 955        session: &Session,
 956    ) -> Result<()> {
 957        self.peer.send(
 958            connection_id,
 959            proto::Hello {
 960                peer_id: Some(connection_id.into()),
 961            },
 962        )?;
 963        tracing::info!("sent hello message");
 964        if let Some(send_connection_id) = send_connection_id.take() {
 965            let _ = send_connection_id.send(connection_id);
 966        }
 967
 968        match &session.principal {
 969            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 970                if !user.connected_once {
 971                    self.peer.send(connection_id, proto::ShowContacts {})?;
 972                    self.app_state
 973                        .db
 974                        .set_user_connected_once(user.id, true)
 975                        .await?;
 976                }
 977
 978                let contacts = self.app_state.db.get_contacts(user.id).await?;
 979
 980                {
 981                    let mut pool = self.connection_pool.lock();
 982                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 983                    self.peer.send(
 984                        connection_id,
 985                        build_initial_contacts_update(contacts, &pool),
 986                    )?;
 987                }
 988
 989                if should_auto_subscribe_to_channels(zed_version) {
 990                    subscribe_user_to_channels(user.id, session).await?;
 991                }
 992
 993                if let Some(incoming_call) =
 994                    self.app_state.db.incoming_call_for_user(user.id).await?
 995                {
 996                    self.peer.send(connection_id, incoming_call)?;
 997                }
 998
 999                update_user_contacts(user.id, session).await?;
1000            }
1001        }
1002
1003        Ok(())
1004    }
1005
1006    pub async fn invite_code_redeemed(
1007        self: &Arc<Self>,
1008        inviter_id: UserId,
1009        invitee_id: UserId,
1010    ) -> Result<()> {
1011        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await?
1012            && let Some(code) = &user.invite_code
1013        {
1014            let pool = self.connection_pool.lock();
1015            let invitee_contact = contact_for_user(invitee_id, false, &pool);
1016            for connection_id in pool.user_connection_ids(inviter_id) {
1017                self.peer.send(
1018                    connection_id,
1019                    proto::UpdateContacts {
1020                        contacts: vec![invitee_contact.clone()],
1021                        ..Default::default()
1022                    },
1023                )?;
1024                self.peer.send(
1025                    connection_id,
1026                    proto::UpdateInviteInfo {
1027                        url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1028                        count: user.invite_count as u32,
1029                    },
1030                )?;
1031            }
1032        }
1033        Ok(())
1034    }
1035
1036    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1037        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await?
1038            && let Some(invite_code) = &user.invite_code
1039        {
1040            let pool = self.connection_pool.lock();
1041            for connection_id in pool.user_connection_ids(user_id) {
1042                self.peer.send(
1043                    connection_id,
1044                    proto::UpdateInviteInfo {
1045                        url: format!(
1046                            "{}{}",
1047                            self.app_state.config.invite_link_prefix, invite_code
1048                        ),
1049                        count: user.invite_count as u32,
1050                    },
1051                )?;
1052            }
1053        }
1054        Ok(())
1055    }
1056
1057    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1058        ServerSnapshot {
1059            connection_pool: ConnectionPoolGuard {
1060                guard: self.connection_pool.lock(),
1061                _not_send: PhantomData,
1062            },
1063            peer: &self.peer,
1064        }
1065    }
1066}
1067
/// Lets a `ConnectionPoolGuard` be used wherever a `&ConnectionPool` is expected.
impl Deref for ConnectionPoolGuard<'_> {
    type Target = ConnectionPool;

    /// Borrows the pool held by the inner lock guard.
    fn deref(&self) -> &Self::Target {
        &self.guard
    }
}
1075
/// Lets a `ConnectionPoolGuard` be used wherever a `&mut ConnectionPool` is expected.
impl DerefMut for ConnectionPoolGuard<'_> {
    /// Mutably borrows the pool held by the inner lock guard.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard
    }
}
1081
impl Drop for ConnectionPoolGuard<'_> {
    /// In test builds, runs `check_invariants` every time a guard is released.
    /// In non-test builds the statement is compiled out and dropping does nothing
    /// extra.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1088
1089fn broadcast<F>(
1090    sender_id: Option<ConnectionId>,
1091    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1092    mut f: F,
1093) where
1094    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1095{
1096    for receiver_id in receiver_ids {
1097        if Some(receiver_id) != sender_id
1098            && let Err(error) = f(receiver_id)
1099        {
1100            tracing::error!("failed to send to {:?} {}", receiver_id, error);
1101        }
1102    }
1103}
1104
1105pub struct ProtocolVersion(u32);
1106
1107impl Header for ProtocolVersion {
1108    fn name() -> &'static HeaderName {
1109        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1110        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1111    }
1112
1113    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1114    where
1115        Self: Sized,
1116        I: Iterator<Item = &'i axum::http::HeaderValue>,
1117    {
1118        let version = values
1119            .next()
1120            .ok_or_else(axum::headers::Error::invalid)?
1121            .to_str()
1122            .map_err(|_| axum::headers::Error::invalid())?
1123            .parse()
1124            .map_err(|_| axum::headers::Error::invalid())?;
1125        Ok(Self(version))
1126    }
1127
1128    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1129        values.extend([self.0.to_string().parse().unwrap()]);
1130    }
1131}
1132
1133pub struct AppVersionHeader(SemanticVersion);
1134impl Header for AppVersionHeader {
1135    fn name() -> &'static HeaderName {
1136        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1137        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1138    }
1139
1140    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1141    where
1142        Self: Sized,
1143        I: Iterator<Item = &'i axum::http::HeaderValue>,
1144    {
1145        let version = values
1146            .next()
1147            .ok_or_else(axum::headers::Error::invalid)?
1148            .to_str()
1149            .map_err(|_| axum::headers::Error::invalid())?
1150            .parse()
1151            .map_err(|_| axum::headers::Error::invalid())?;
1152        Ok(Self(version))
1153    }
1154
1155    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1156        values.extend([self.0.to_string().parse().unwrap()]);
1157    }
1158}
1159
1160#[derive(Debug)]
1161pub struct ReleaseChannelHeader(String);
1162
1163impl Header for ReleaseChannelHeader {
1164    fn name() -> &'static HeaderName {
1165        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1166        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1167    }
1168
1169    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1170    where
1171        Self: Sized,
1172        I: Iterator<Item = &'i axum::http::HeaderValue>,
1173    {
1174        Ok(Self(
1175            values
1176                .next()
1177                .ok_or_else(axum::headers::Error::invalid)?
1178                .to_str()
1179                .map_err(|_| axum::headers::Error::invalid())?
1180                .to_owned(),
1181        ))
1182    }
1183
1184    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1185        values.extend([self.0.parse().unwrap()]);
1186    }
1187}
1188
1189pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1190    Router::new()
1191        .route("/rpc", get(handle_websocket_request))
1192        .layer(
1193            ServiceBuilder::new()
1194                .layer(Extension(server.app_state.clone()))
1195                .layer(middleware::from_fn(auth::validate_header)),
1196        )
1197        .route("/metrics", get(handle_metrics))
1198        .layer(Extension(server))
1199}
1200
1201pub async fn handle_websocket_request(
1202    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1203    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1204    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1205    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1206    Extension(server): Extension<Arc<Server>>,
1207    Extension(principal): Extension<Principal>,
1208    user_agent: Option<TypedHeader<UserAgent>>,
1209    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1210    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1211    ws: WebSocketUpgrade,
1212) -> axum::response::Response {
1213    if protocol_version != rpc::PROTOCOL_VERSION {
1214        return (
1215            StatusCode::UPGRADE_REQUIRED,
1216            "client must be upgraded".to_string(),
1217        )
1218            .into_response();
1219    }
1220
1221    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1222        return (
1223            StatusCode::UPGRADE_REQUIRED,
1224            "no version header found".to_string(),
1225        )
1226            .into_response();
1227    };
1228
1229    let release_channel = release_channel_header.map(|header| header.0.0);
1230
1231    if !version.can_collaborate() {
1232        return (
1233            StatusCode::UPGRADE_REQUIRED,
1234            "client must be upgraded".to_string(),
1235        )
1236            .into_response();
1237    }
1238
1239    let socket_address = socket_address.to_string();
1240
1241    // Acquire connection guard before WebSocket upgrade
1242    let connection_guard = match ConnectionGuard::try_acquire() {
1243        Ok(guard) => guard,
1244        Err(()) => {
1245            return (
1246                StatusCode::SERVICE_UNAVAILABLE,
1247                "Too many concurrent connections",
1248            )
1249                .into_response();
1250        }
1251    };
1252
1253    ws.on_upgrade(move |socket| {
1254        let socket = socket
1255            .map_ok(to_tungstenite_message)
1256            .err_into()
1257            .with(|message| async move { to_axum_message(message) });
1258        let connection = Connection::new(Box::pin(socket));
1259        async move {
1260            server
1261                .handle_connection(
1262                    connection,
1263                    socket_address,
1264                    principal,
1265                    version,
1266                    release_channel,
1267                    user_agent.map(|header| header.to_string()),
1268                    country_code_header.map(|header| header.to_string()),
1269                    system_id_header.map(|header| header.to_string()),
1270                    None,
1271                    Executor::Production,
1272                    Some(connection_guard),
1273                )
1274                .await;
1275        }
1276    })
1277}
1278
1279pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1280    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1281    let connections_metric = CONNECTIONS_METRIC
1282        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1283
1284    let connections = server
1285        .connection_pool
1286        .lock()
1287        .connections()
1288        .filter(|connection| !connection.admin)
1289        .count();
1290    connections_metric.set(connections as _);
1291
1292    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1293    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1294        register_int_gauge!(
1295            "shared_projects",
1296            "number of open projects with one or more guests"
1297        )
1298        .unwrap()
1299    });
1300
1301    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1302    shared_projects_metric.set(shared_projects as _);
1303
1304    let encoder = prometheus::TextEncoder::new();
1305    let metric_families = prometheus::gather();
1306    let encoded_metrics = encoder
1307        .encode_to_string(&metric_families)
1308        .map_err(|err| anyhow!("{err}"))?;
1309    Ok(encoded_metrics)
1310}
1311
/// Signs a connection out after it has gone away.
///
/// Immediately disconnects it from the peer and removes it from the connection
/// pool; if the pool removal fails, the error is returned and nothing else runs.
/// It then reports the loss to the database. After that it waits for either
/// `RECONNECT_TIMEOUT` to elapse or the teardown channel to change. On timeout it:
/// - removes the user from their room and channel buffers;
/// - declines any pending call if the user has no other connection online;
/// - updates the user's contacts.
///
/// If the teardown channel changes first, none of that runs.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls its branches in order: the reconnect timer first,
    // then the teardown channel. The wait presumably gives the client a window to
    // reconnect before its resources are released.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call when no other connection of this user
            // remains online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1358
1359/// Acknowledges a ping from a client, used to keep the connection alive.
1360async fn ping(
1361    _: proto::Ping,
1362    response: Response<proto::Ping>,
1363    _session: MessageContext,
1364) -> Result<()> {
1365    response.send(proto::Ack {})?;
1366    Ok(())
1367}
1368
1369/// Creates a new room for calling (outside of channels)
1370async fn create_room(
1371    _request: proto::CreateRoom,
1372    response: Response<proto::CreateRoom>,
1373    session: MessageContext,
1374) -> Result<()> {
1375    let livekit_room = nanoid::nanoid!(30);
1376
1377    let live_kit_connection_info = util::maybe!(async {
1378        let live_kit = session.app_state.livekit_client.as_ref();
1379        let live_kit = live_kit?;
1380        let user_id = session.user_id().to_string();
1381
1382        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1383
1384        Some(proto::LiveKitConnectionInfo {
1385            server_url: live_kit.url().into(),
1386            token,
1387            can_publish: true,
1388        })
1389    })
1390    .await;
1391
1392    let room = session
1393        .db()
1394        .await
1395        .create_room(session.user_id(), session.connection_id, &livekit_room)
1396        .await?;
1397
1398    response.send(proto::CreateRoomResponse {
1399        room: Some(room.clone()),
1400        live_kit_connection_info,
1401    })?;
1402
1403    update_user_contacts(session.user_id(), &session).await?;
1404    Ok(())
1405}
1406
1407/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1408async fn join_room(
1409    request: proto::JoinRoom,
1410    response: Response<proto::JoinRoom>,
1411    session: MessageContext,
1412) -> Result<()> {
1413    let room_id = RoomId::from_proto(request.id);
1414
1415    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1416
1417    if let Some(channel_id) = channel_id {
1418        return join_channel_internal(channel_id, Box::new(response), session).await;
1419    }
1420
1421    let joined_room = {
1422        let room = session
1423            .db()
1424            .await
1425            .join_room(room_id, session.user_id(), session.connection_id)
1426            .await?;
1427        room_updated(&room.room, &session.peer);
1428        room.into_inner()
1429    };
1430
1431    for connection_id in session
1432        .connection_pool()
1433        .await
1434        .user_connection_ids(session.user_id())
1435    {
1436        session
1437            .peer
1438            .send(
1439                connection_id,
1440                proto::CallCanceled {
1441                    room_id: room_id.to_proto(),
1442                },
1443            )
1444            .trace_err();
1445    }
1446
1447    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1448    {
1449        live_kit
1450            .room_token(
1451                &joined_room.room.livekit_room,
1452                &session.user_id().to_string(),
1453            )
1454            .trace_err()
1455            .map(|token| proto::LiveKitConnectionInfo {
1456                server_url: live_kit.url().into(),
1457                token,
1458                can_publish: true,
1459            })
1460    } else {
1461        None
1462    };
1463
1464    response.send(proto::JoinRoomResponse {
1465        room: Some(joined_room.room),
1466        channel_id: None,
1467        live_kit_connection_info,
1468    })?;
1469
1470    update_user_contacts(session.user_id(), &session).await?;
1471    Ok(())
1472}
1473
1474/// Rejoin room is used to reconnect to a room after connection errors.
1475async fn rejoin_room(
1476    request: proto::RejoinRoom,
1477    response: Response<proto::RejoinRoom>,
1478    session: MessageContext,
1479) -> Result<()> {
1480    let room;
1481    let channel;
1482    {
1483        let mut rejoined_room = session
1484            .db()
1485            .await
1486            .rejoin_room(request, session.user_id(), session.connection_id)
1487            .await?;
1488
1489        response.send(proto::RejoinRoomResponse {
1490            room: Some(rejoined_room.room.clone()),
1491            reshared_projects: rejoined_room
1492                .reshared_projects
1493                .iter()
1494                .map(|project| proto::ResharedProject {
1495                    id: project.id.to_proto(),
1496                    collaborators: project
1497                        .collaborators
1498                        .iter()
1499                        .map(|collaborator| collaborator.to_proto())
1500                        .collect(),
1501                })
1502                .collect(),
1503            rejoined_projects: rejoined_room
1504                .rejoined_projects
1505                .iter()
1506                .map(|rejoined_project| rejoined_project.to_proto())
1507                .collect(),
1508        })?;
1509        room_updated(&rejoined_room.room, &session.peer);
1510
1511        for project in &rejoined_room.reshared_projects {
1512            for collaborator in &project.collaborators {
1513                session
1514                    .peer
1515                    .send(
1516                        collaborator.connection_id,
1517                        proto::UpdateProjectCollaborator {
1518                            project_id: project.id.to_proto(),
1519                            old_peer_id: Some(project.old_connection_id.into()),
1520                            new_peer_id: Some(session.connection_id.into()),
1521                        },
1522                    )
1523                    .trace_err();
1524            }
1525
1526            broadcast(
1527                Some(session.connection_id),
1528                project
1529                    .collaborators
1530                    .iter()
1531                    .map(|collaborator| collaborator.connection_id),
1532                |connection_id| {
1533                    session.peer.forward_send(
1534                        session.connection_id,
1535                        connection_id,
1536                        proto::UpdateProject {
1537                            project_id: project.id.to_proto(),
1538                            worktrees: project.worktrees.clone(),
1539                        },
1540                    )
1541                },
1542            );
1543        }
1544
1545        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1546
1547        let rejoined_room = rejoined_room.into_inner();
1548
1549        room = rejoined_room.room;
1550        channel = rejoined_room.channel;
1551    }
1552
1553    if let Some(channel) = channel {
1554        channel_updated(
1555            &channel,
1556            &room,
1557            &session.peer,
1558            &*session.connection_pool().await,
1559        );
1560    }
1561
1562    update_user_contacts(session.user_id(), &session).await?;
1563    Ok(())
1564}
1565
1566fn notify_rejoined_projects(
1567    rejoined_projects: &mut Vec<RejoinedProject>,
1568    session: &Session,
1569) -> Result<()> {
1570    for project in rejoined_projects.iter() {
1571        for collaborator in &project.collaborators {
1572            session
1573                .peer
1574                .send(
1575                    collaborator.connection_id,
1576                    proto::UpdateProjectCollaborator {
1577                        project_id: project.id.to_proto(),
1578                        old_peer_id: Some(project.old_connection_id.into()),
1579                        new_peer_id: Some(session.connection_id.into()),
1580                    },
1581                )
1582                .trace_err();
1583        }
1584    }
1585
1586    for project in rejoined_projects {
1587        for worktree in mem::take(&mut project.worktrees) {
1588            // Stream this worktree's entries.
1589            let message = proto::UpdateWorktree {
1590                project_id: project.id.to_proto(),
1591                worktree_id: worktree.id,
1592                abs_path: worktree.abs_path.clone(),
1593                root_name: worktree.root_name,
1594                updated_entries: worktree.updated_entries,
1595                removed_entries: worktree.removed_entries,
1596                scan_id: worktree.scan_id,
1597                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1598                updated_repositories: worktree.updated_repositories,
1599                removed_repositories: worktree.removed_repositories,
1600            };
1601            for update in proto::split_worktree_update(message) {
1602                session.peer.send(session.connection_id, update)?;
1603            }
1604
1605            // Stream this worktree's diagnostics.
1606            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1607            if let Some(summary) = worktree_diagnostics.next() {
1608                let message = proto::UpdateDiagnosticSummary {
1609                    project_id: project.id.to_proto(),
1610                    worktree_id: worktree.id,
1611                    summary: Some(summary),
1612                    more_summaries: worktree_diagnostics.collect(),
1613                };
1614                session.peer.send(session.connection_id, message)?;
1615            }
1616
1617            for settings_file in worktree.settings_files {
1618                session.peer.send(
1619                    session.connection_id,
1620                    proto::UpdateWorktreeSettings {
1621                        project_id: project.id.to_proto(),
1622                        worktree_id: worktree.id,
1623                        path: settings_file.path,
1624                        content: Some(settings_file.content),
1625                        kind: Some(settings_file.kind.to_proto().into()),
1626                    },
1627                )?;
1628            }
1629        }
1630
1631        for repository in mem::take(&mut project.updated_repositories) {
1632            for update in split_repository_update(repository) {
1633                session.peer.send(session.connection_id, update)?;
1634            }
1635        }
1636
1637        for id in mem::take(&mut project.removed_repositories) {
1638            session.peer.send(
1639                session.connection_id,
1640                proto::RemoveRepository {
1641                    project_id: project.id.to_proto(),
1642                    id,
1643                },
1644            )?;
1645        }
1646    }
1647
1648    Ok(())
1649}
1650
1651/// leave room disconnects from the room.
1652async fn leave_room(
1653    _: proto::LeaveRoom,
1654    response: Response<proto::LeaveRoom>,
1655    session: MessageContext,
1656) -> Result<()> {
1657    leave_room_for_session(&session, session.connection_id).await?;
1658    response.send(proto::Ack {})?;
1659    Ok(())
1660}
1661
1662/// Updates the permissions of someone else in the room.
1663async fn set_room_participant_role(
1664    request: proto::SetRoomParticipantRole,
1665    response: Response<proto::SetRoomParticipantRole>,
1666    session: MessageContext,
1667) -> Result<()> {
1668    let user_id = UserId::from_proto(request.user_id);
1669    let role = ChannelRole::from(request.role());
1670
1671    let (livekit_room, can_publish) = {
1672        let room = session
1673            .db()
1674            .await
1675            .set_room_participant_role(
1676                session.user_id(),
1677                RoomId::from_proto(request.room_id),
1678                user_id,
1679                role,
1680            )
1681            .await?;
1682
1683        let livekit_room = room.livekit_room.clone();
1684        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1685        room_updated(&room, &session.peer);
1686        (livekit_room, can_publish)
1687    };
1688
1689    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1690        live_kit
1691            .update_participant(
1692                livekit_room.clone(),
1693                request.user_id.to_string(),
1694                livekit_api::proto::ParticipantPermission {
1695                    can_subscribe: true,
1696                    can_publish,
1697                    can_publish_data: can_publish,
1698                    hidden: false,
1699                    recorder: false,
1700                },
1701            )
1702            .await
1703            .trace_err();
1704    }
1705
1706    response.send(proto::Ack {})?;
1707    Ok(())
1708}
1709
1710/// Call someone else into the current room
1711async fn call(
1712    request: proto::Call,
1713    response: Response<proto::Call>,
1714    session: MessageContext,
1715) -> Result<()> {
1716    let room_id = RoomId::from_proto(request.room_id);
1717    let calling_user_id = session.user_id();
1718    let calling_connection_id = session.connection_id;
1719    let called_user_id = UserId::from_proto(request.called_user_id);
1720    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1721    if !session
1722        .db()
1723        .await
1724        .has_contact(calling_user_id, called_user_id)
1725        .await?
1726    {
1727        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1728    }
1729
1730    let incoming_call = {
1731        let (room, incoming_call) = &mut *session
1732            .db()
1733            .await
1734            .call(
1735                room_id,
1736                calling_user_id,
1737                calling_connection_id,
1738                called_user_id,
1739                initial_project_id,
1740            )
1741            .await?;
1742        room_updated(room, &session.peer);
1743        mem::take(incoming_call)
1744    };
1745    update_user_contacts(called_user_id, &session).await?;
1746
1747    let mut calls = session
1748        .connection_pool()
1749        .await
1750        .user_connection_ids(called_user_id)
1751        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1752        .collect::<FuturesUnordered<_>>();
1753
1754    while let Some(call_response) = calls.next().await {
1755        match call_response.as_ref() {
1756            Ok(_) => {
1757                response.send(proto::Ack {})?;
1758                return Ok(());
1759            }
1760            Err(_) => {
1761                call_response.trace_err();
1762            }
1763        }
1764    }
1765
1766    {
1767        let room = session
1768            .db()
1769            .await
1770            .call_failed(room_id, called_user_id)
1771            .await?;
1772        room_updated(&room, &session.peer);
1773    }
1774    update_user_contacts(called_user_id, &session).await?;
1775
1776    Err(anyhow!("failed to ring user"))?
1777}
1778
1779/// Cancel an outgoing call.
1780async fn cancel_call(
1781    request: proto::CancelCall,
1782    response: Response<proto::CancelCall>,
1783    session: MessageContext,
1784) -> Result<()> {
1785    let called_user_id = UserId::from_proto(request.called_user_id);
1786    let room_id = RoomId::from_proto(request.room_id);
1787    {
1788        let room = session
1789            .db()
1790            .await
1791            .cancel_call(room_id, session.connection_id, called_user_id)
1792            .await?;
1793        room_updated(&room, &session.peer);
1794    }
1795
1796    for connection_id in session
1797        .connection_pool()
1798        .await
1799        .user_connection_ids(called_user_id)
1800    {
1801        session
1802            .peer
1803            .send(
1804                connection_id,
1805                proto::CallCanceled {
1806                    room_id: room_id.to_proto(),
1807                },
1808            )
1809            .trace_err();
1810    }
1811    response.send(proto::Ack {})?;
1812
1813    update_user_contacts(called_user_id, &session).await?;
1814    Ok(())
1815}
1816
1817/// Decline an incoming call.
1818async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1819    let room_id = RoomId::from_proto(message.room_id);
1820    {
1821        let room = session
1822            .db()
1823            .await
1824            .decline_call(Some(room_id), session.user_id())
1825            .await?
1826            .context("declining call")?;
1827        room_updated(&room, &session.peer);
1828    }
1829
1830    for connection_id in session
1831        .connection_pool()
1832        .await
1833        .user_connection_ids(session.user_id())
1834    {
1835        session
1836            .peer
1837            .send(
1838                connection_id,
1839                proto::CallCanceled {
1840                    room_id: room_id.to_proto(),
1841                },
1842            )
1843            .trace_err();
1844    }
1845    update_user_contacts(session.user_id(), &session).await?;
1846    Ok(())
1847}
1848
1849/// Updates other participants in the room with your current location.
1850async fn update_participant_location(
1851    request: proto::UpdateParticipantLocation,
1852    response: Response<proto::UpdateParticipantLocation>,
1853    session: MessageContext,
1854) -> Result<()> {
1855    let room_id = RoomId::from_proto(request.room_id);
1856    let location = request.location.context("invalid location")?;
1857
1858    let db = session.db().await;
1859    let room = db
1860        .update_room_participant_location(room_id, session.connection_id, location)
1861        .await?;
1862
1863    room_updated(&room, &session.peer);
1864    response.send(proto::Ack {})?;
1865    Ok(())
1866}
1867
1868/// Share a project into the room.
1869async fn share_project(
1870    request: proto::ShareProject,
1871    response: Response<proto::ShareProject>,
1872    session: MessageContext,
1873) -> Result<()> {
1874    let (project_id, room) = &*session
1875        .db()
1876        .await
1877        .share_project(
1878            RoomId::from_proto(request.room_id),
1879            session.connection_id,
1880            &request.worktrees,
1881            request.is_ssh_project,
1882            request.windows_paths.unwrap_or(false),
1883        )
1884        .await?;
1885    response.send(proto::ShareProjectResponse {
1886        project_id: project_id.to_proto(),
1887    })?;
1888    room_updated(room, &session.peer);
1889
1890    Ok(())
1891}
1892
1893/// Unshare a project from the room.
1894async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1895    let project_id = ProjectId::from_proto(message.project_id);
1896    unshare_project_internal(project_id, session.connection_id, &session).await
1897}
1898
1899async fn unshare_project_internal(
1900    project_id: ProjectId,
1901    connection_id: ConnectionId,
1902    session: &Session,
1903) -> Result<()> {
1904    let delete = {
1905        let room_guard = session
1906            .db()
1907            .await
1908            .unshare_project(project_id, connection_id)
1909            .await?;
1910
1911        let (delete, room, guest_connection_ids) = &*room_guard;
1912
1913        let message = proto::UnshareProject {
1914            project_id: project_id.to_proto(),
1915        };
1916
1917        broadcast(
1918            Some(connection_id),
1919            guest_connection_ids.iter().copied(),
1920            |conn_id| session.peer.send(conn_id, message.clone()),
1921        );
1922        if let Some(room) = room {
1923            room_updated(room, &session.peer);
1924        }
1925
1926        *delete
1927    };
1928
1929    if delete {
1930        let db = session.db().await;
1931        db.delete_project(project_id).await?;
1932    }
1933
1934    Ok(())
1935}
1936
1937/// Join someone elses shared project.
1938async fn join_project(
1939    request: proto::JoinProject,
1940    response: Response<proto::JoinProject>,
1941    session: MessageContext,
1942) -> Result<()> {
1943    let project_id = ProjectId::from_proto(request.project_id);
1944
1945    tracing::info!(%project_id, "join project");
1946
1947    let db = session.db().await;
1948    let (project, replica_id) = &mut *db
1949        .join_project(
1950            project_id,
1951            session.connection_id,
1952            session.user_id(),
1953            request.committer_name.clone(),
1954            request.committer_email.clone(),
1955        )
1956        .await?;
1957    drop(db);
1958    tracing::info!(%project_id, "join remote project");
1959    let collaborators = project
1960        .collaborators
1961        .iter()
1962        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1963        .map(|collaborator| collaborator.to_proto())
1964        .collect::<Vec<_>>();
1965    let project_id = project.id;
1966    let guest_user_id = session.user_id();
1967
1968    let worktrees = project
1969        .worktrees
1970        .iter()
1971        .map(|(id, worktree)| proto::WorktreeMetadata {
1972            id: *id,
1973            root_name: worktree.root_name.clone(),
1974            visible: worktree.visible,
1975            abs_path: worktree.abs_path.clone(),
1976        })
1977        .collect::<Vec<_>>();
1978
1979    let add_project_collaborator = proto::AddProjectCollaborator {
1980        project_id: project_id.to_proto(),
1981        collaborator: Some(proto::Collaborator {
1982            peer_id: Some(session.connection_id.into()),
1983            replica_id: replica_id.0 as u32,
1984            user_id: guest_user_id.to_proto(),
1985            is_host: false,
1986            committer_name: request.committer_name.clone(),
1987            committer_email: request.committer_email.clone(),
1988        }),
1989    };
1990
1991    for collaborator in &collaborators {
1992        session
1993            .peer
1994            .send(
1995                collaborator.peer_id.unwrap().into(),
1996                add_project_collaborator.clone(),
1997            )
1998            .trace_err();
1999    }
2000
2001    // First, we send the metadata associated with each worktree.
2002    let (language_servers, language_server_capabilities) = project
2003        .language_servers
2004        .clone()
2005        .into_iter()
2006        .map(|server| (server.server, server.capabilities))
2007        .unzip();
2008    response.send(proto::JoinProjectResponse {
2009        project_id: project.id.0 as u64,
2010        worktrees,
2011        replica_id: replica_id.0 as u32,
2012        collaborators,
2013        language_servers,
2014        language_server_capabilities,
2015        role: project.role.into(),
2016        windows_paths: project.path_style == PathStyle::Windows,
2017    })?;
2018
2019    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2020        // Stream this worktree's entries.
2021        let message = proto::UpdateWorktree {
2022            project_id: project_id.to_proto(),
2023            worktree_id,
2024            abs_path: worktree.abs_path.clone(),
2025            root_name: worktree.root_name,
2026            updated_entries: worktree.entries,
2027            removed_entries: Default::default(),
2028            scan_id: worktree.scan_id,
2029            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2030            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2031            removed_repositories: Default::default(),
2032        };
2033        for update in proto::split_worktree_update(message) {
2034            session.peer.send(session.connection_id, update.clone())?;
2035        }
2036
2037        // Stream this worktree's diagnostics.
2038        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
2039        if let Some(summary) = worktree_diagnostics.next() {
2040            let message = proto::UpdateDiagnosticSummary {
2041                project_id: project.id.to_proto(),
2042                worktree_id: worktree.id,
2043                summary: Some(summary),
2044                more_summaries: worktree_diagnostics.collect(),
2045            };
2046            session.peer.send(session.connection_id, message)?;
2047        }
2048
2049        for settings_file in worktree.settings_files {
2050            session.peer.send(
2051                session.connection_id,
2052                proto::UpdateWorktreeSettings {
2053                    project_id: project_id.to_proto(),
2054                    worktree_id: worktree.id,
2055                    path: settings_file.path,
2056                    content: Some(settings_file.content),
2057                    kind: Some(settings_file.kind.to_proto() as i32),
2058                },
2059            )?;
2060        }
2061    }
2062
2063    for repository in mem::take(&mut project.repositories) {
2064        for update in split_repository_update(repository) {
2065            session.peer.send(session.connection_id, update)?;
2066        }
2067    }
2068
2069    for language_server in &project.language_servers {
2070        session.peer.send(
2071            session.connection_id,
2072            proto::UpdateLanguageServer {
2073                project_id: project_id.to_proto(),
2074                server_name: Some(language_server.server.name.clone()),
2075                language_server_id: language_server.server.id,
2076                variant: Some(
2077                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2078                        proto::LspDiskBasedDiagnosticsUpdated {},
2079                    ),
2080                ),
2081            },
2082        )?;
2083    }
2084
2085    Ok(())
2086}
2087
2088/// Leave someone elses shared project.
2089async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2090    let sender_id = session.connection_id;
2091    let project_id = ProjectId::from_proto(request.project_id);
2092    let db = session.db().await;
2093
2094    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2095    tracing::info!(
2096        %project_id,
2097        "leave project"
2098    );
2099
2100    project_left(project, &session);
2101    if let Some(room) = room {
2102        room_updated(room, &session.peer);
2103    }
2104
2105    Ok(())
2106}
2107
2108/// Updates other participants with changes to the project
2109async fn update_project(
2110    request: proto::UpdateProject,
2111    response: Response<proto::UpdateProject>,
2112    session: MessageContext,
2113) -> Result<()> {
2114    let project_id = ProjectId::from_proto(request.project_id);
2115    let (room, guest_connection_ids) = &*session
2116        .db()
2117        .await
2118        .update_project(project_id, session.connection_id, &request.worktrees)
2119        .await?;
2120    broadcast(
2121        Some(session.connection_id),
2122        guest_connection_ids.iter().copied(),
2123        |connection_id| {
2124            session
2125                .peer
2126                .forward_send(session.connection_id, connection_id, request.clone())
2127        },
2128    );
2129    if let Some(room) = room {
2130        room_updated(room, &session.peer);
2131    }
2132    response.send(proto::Ack {})?;
2133
2134    Ok(())
2135}
2136
2137/// Updates other participants with changes to the worktree
2138async fn update_worktree(
2139    request: proto::UpdateWorktree,
2140    response: Response<proto::UpdateWorktree>,
2141    session: MessageContext,
2142) -> Result<()> {
2143    let guest_connection_ids = session
2144        .db()
2145        .await
2146        .update_worktree(&request, session.connection_id)
2147        .await?;
2148
2149    broadcast(
2150        Some(session.connection_id),
2151        guest_connection_ids.iter().copied(),
2152        |connection_id| {
2153            session
2154                .peer
2155                .forward_send(session.connection_id, connection_id, request.clone())
2156        },
2157    );
2158    response.send(proto::Ack {})?;
2159    Ok(())
2160}
2161
2162async fn update_repository(
2163    request: proto::UpdateRepository,
2164    response: Response<proto::UpdateRepository>,
2165    session: MessageContext,
2166) -> Result<()> {
2167    let guest_connection_ids = session
2168        .db()
2169        .await
2170        .update_repository(&request, session.connection_id)
2171        .await?;
2172
2173    broadcast(
2174        Some(session.connection_id),
2175        guest_connection_ids.iter().copied(),
2176        |connection_id| {
2177            session
2178                .peer
2179                .forward_send(session.connection_id, connection_id, request.clone())
2180        },
2181    );
2182    response.send(proto::Ack {})?;
2183    Ok(())
2184}
2185
2186async fn remove_repository(
2187    request: proto::RemoveRepository,
2188    response: Response<proto::RemoveRepository>,
2189    session: MessageContext,
2190) -> Result<()> {
2191    let guest_connection_ids = session
2192        .db()
2193        .await
2194        .remove_repository(&request, session.connection_id)
2195        .await?;
2196
2197    broadcast(
2198        Some(session.connection_id),
2199        guest_connection_ids.iter().copied(),
2200        |connection_id| {
2201            session
2202                .peer
2203                .forward_send(session.connection_id, connection_id, request.clone())
2204        },
2205    );
2206    response.send(proto::Ack {})?;
2207    Ok(())
2208}
2209
2210/// Updates other participants with changes to the diagnostics
2211async fn update_diagnostic_summary(
2212    message: proto::UpdateDiagnosticSummary,
2213    session: MessageContext,
2214) -> Result<()> {
2215    let guest_connection_ids = session
2216        .db()
2217        .await
2218        .update_diagnostic_summary(&message, session.connection_id)
2219        .await?;
2220
2221    broadcast(
2222        Some(session.connection_id),
2223        guest_connection_ids.iter().copied(),
2224        |connection_id| {
2225            session
2226                .peer
2227                .forward_send(session.connection_id, connection_id, message.clone())
2228        },
2229    );
2230
2231    Ok(())
2232}
2233
2234/// Updates other participants with changes to the worktree settings
2235async fn update_worktree_settings(
2236    message: proto::UpdateWorktreeSettings,
2237    session: MessageContext,
2238) -> Result<()> {
2239    let guest_connection_ids = session
2240        .db()
2241        .await
2242        .update_worktree_settings(&message, session.connection_id)
2243        .await?;
2244
2245    broadcast(
2246        Some(session.connection_id),
2247        guest_connection_ids.iter().copied(),
2248        |connection_id| {
2249            session
2250                .peer
2251                .forward_send(session.connection_id, connection_id, message.clone())
2252        },
2253    );
2254
2255    Ok(())
2256}
2257
2258/// Notify other participants that a language server has started.
2259async fn start_language_server(
2260    request: proto::StartLanguageServer,
2261    session: MessageContext,
2262) -> Result<()> {
2263    let guest_connection_ids = session
2264        .db()
2265        .await
2266        .start_language_server(&request, session.connection_id)
2267        .await?;
2268
2269    broadcast(
2270        Some(session.connection_id),
2271        guest_connection_ids.iter().copied(),
2272        |connection_id| {
2273            session
2274                .peer
2275                .forward_send(session.connection_id, connection_id, request.clone())
2276        },
2277    );
2278    Ok(())
2279}
2280
2281/// Notify other participants that a language server has changed.
2282async fn update_language_server(
2283    request: proto::UpdateLanguageServer,
2284    session: MessageContext,
2285) -> Result<()> {
2286    let project_id = ProjectId::from_proto(request.project_id);
2287    let db = session.db().await;
2288
2289    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2290        && let Some(capabilities) = update.capabilities.clone()
2291    {
2292        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2293            .await?;
2294    }
2295
2296    let project_connection_ids = db
2297        .project_connection_ids(project_id, session.connection_id, true)
2298        .await?;
2299    broadcast(
2300        Some(session.connection_id),
2301        project_connection_ids.iter().copied(),
2302        |connection_id| {
2303            session
2304                .peer
2305                .forward_send(session.connection_id, connection_id, request.clone())
2306        },
2307    );
2308    Ok(())
2309}
2310
2311/// forward a project request to the host. These requests should be read only
2312/// as guests are allowed to send them.
2313async fn forward_read_only_project_request<T>(
2314    request: T,
2315    response: Response<T>,
2316    session: MessageContext,
2317) -> Result<()>
2318where
2319    T: EntityMessage + RequestMessage,
2320{
2321    let project_id = ProjectId::from_proto(request.remote_entity_id());
2322    let host_connection_id = session
2323        .db()
2324        .await
2325        .host_for_read_only_project_request(project_id, session.connection_id)
2326        .await?;
2327    let payload = session.forward_request(host_connection_id, request).await?;
2328    response.send(payload)?;
2329    Ok(())
2330}
2331
2332/// forward a project request to the host. These requests are disallowed
2333/// for guests.
2334async fn forward_mutating_project_request<T>(
2335    request: T,
2336    response: Response<T>,
2337    session: MessageContext,
2338) -> Result<()>
2339where
2340    T: EntityMessage + RequestMessage,
2341{
2342    let project_id = ProjectId::from_proto(request.remote_entity_id());
2343
2344    let host_connection_id = session
2345        .db()
2346        .await
2347        .host_for_mutating_project_request(project_id, session.connection_id)
2348        .await?;
2349    let payload = session.forward_request(host_connection_id, request).await?;
2350    response.send(payload)?;
2351    Ok(())
2352}
2353
2354async fn lsp_query(
2355    request: proto::LspQuery,
2356    response: Response<proto::LspQuery>,
2357    session: MessageContext,
2358) -> Result<()> {
2359    let (name, should_write) = request.query_name_and_write_permissions();
2360    tracing::Span::current().record("lsp_query_request", name);
2361    tracing::info!("lsp_query message received");
2362    if should_write {
2363        forward_mutating_project_request(request, response, session).await
2364    } else {
2365        forward_read_only_project_request(request, response, session).await
2366    }
2367}
2368
2369/// Notify other participants that a new buffer has been created
2370async fn create_buffer_for_peer(
2371    request: proto::CreateBufferForPeer,
2372    session: MessageContext,
2373) -> Result<()> {
2374    session
2375        .db()
2376        .await
2377        .check_user_is_project_host(
2378            ProjectId::from_proto(request.project_id),
2379            session.connection_id,
2380        )
2381        .await?;
2382    let peer_id = request.peer_id.context("invalid peer id")?;
2383    session
2384        .peer
2385        .forward_send(session.connection_id, peer_id.into(), request)?;
2386    Ok(())
2387}
2388
2389/// Notify other participants that a buffer has been updated. This is
2390/// allowed for guests as long as the update is limited to selections.
2391async fn update_buffer(
2392    request: proto::UpdateBuffer,
2393    response: Response<proto::UpdateBuffer>,
2394    session: MessageContext,
2395) -> Result<()> {
2396    let project_id = ProjectId::from_proto(request.project_id);
2397    let mut capability = Capability::ReadOnly;
2398
2399    for op in request.operations.iter() {
2400        match op.variant {
2401            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2402            Some(_) => capability = Capability::ReadWrite,
2403        }
2404    }
2405
2406    let host = {
2407        let guard = session
2408            .db()
2409            .await
2410            .connections_for_buffer_update(project_id, session.connection_id, capability)
2411            .await?;
2412
2413        let (host, guests) = &*guard;
2414
2415        broadcast(
2416            Some(session.connection_id),
2417            guests.clone(),
2418            |connection_id| {
2419                session
2420                    .peer
2421                    .forward_send(session.connection_id, connection_id, request.clone())
2422            },
2423        );
2424
2425        *host
2426    };
2427
2428    if host != session.connection_id {
2429        session.forward_request(host, request.clone()).await?;
2430    }
2431
2432    response.send(proto::Ack {})?;
2433    Ok(())
2434}
2435
2436async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2437    let project_id = ProjectId::from_proto(message.project_id);
2438
2439    let operation = message.operation.as_ref().context("invalid operation")?;
2440    let capability = match operation.variant.as_ref() {
2441        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2442            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2443                match buffer_op.variant {
2444                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2445                        Capability::ReadOnly
2446                    }
2447                    _ => Capability::ReadWrite,
2448                }
2449            } else {
2450                Capability::ReadWrite
2451            }
2452        }
2453        Some(_) => Capability::ReadWrite,
2454        None => Capability::ReadOnly,
2455    };
2456
2457    let guard = session
2458        .db()
2459        .await
2460        .connections_for_buffer_update(project_id, session.connection_id, capability)
2461        .await?;
2462
2463    let (host, guests) = &*guard;
2464
2465    broadcast(
2466        Some(session.connection_id),
2467        guests.iter().chain([host]).copied(),
2468        |connection_id| {
2469            session
2470                .peer
2471                .forward_send(session.connection_id, connection_id, message.clone())
2472        },
2473    );
2474
2475    Ok(())
2476}
2477
2478/// Notify other participants that a project has been updated.
2479async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2480    request: T,
2481    session: MessageContext,
2482) -> Result<()> {
2483    let project_id = ProjectId::from_proto(request.remote_entity_id());
2484    let project_connection_ids = session
2485        .db()
2486        .await
2487        .project_connection_ids(project_id, session.connection_id, false)
2488        .await?;
2489
2490    broadcast(
2491        Some(session.connection_id),
2492        project_connection_ids.iter().copied(),
2493        |connection_id| {
2494            session
2495                .peer
2496                .forward_send(session.connection_id, connection_id, request.clone())
2497        },
2498    );
2499    Ok(())
2500}
2501
2502/// Start following another user in a call.
2503async fn follow(
2504    request: proto::Follow,
2505    response: Response<proto::Follow>,
2506    session: MessageContext,
2507) -> Result<()> {
2508    let room_id = RoomId::from_proto(request.room_id);
2509    let project_id = request.project_id.map(ProjectId::from_proto);
2510    let leader_id = request.leader_id.context("invalid leader id")?.into();
2511    let follower_id = session.connection_id;
2512
2513    session
2514        .db()
2515        .await
2516        .check_room_participants(room_id, leader_id, session.connection_id)
2517        .await?;
2518
2519    let response_payload = session.forward_request(leader_id, request).await?;
2520    response.send(response_payload)?;
2521
2522    if let Some(project_id) = project_id {
2523        let room = session
2524            .db()
2525            .await
2526            .follow(room_id, project_id, leader_id, follower_id)
2527            .await?;
2528        room_updated(&room, &session.peer);
2529    }
2530
2531    Ok(())
2532}
2533
2534/// Stop following another user in a call.
2535async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2536    let room_id = RoomId::from_proto(request.room_id);
2537    let project_id = request.project_id.map(ProjectId::from_proto);
2538    let leader_id = request.leader_id.context("invalid leader id")?.into();
2539    let follower_id = session.connection_id;
2540
2541    session
2542        .db()
2543        .await
2544        .check_room_participants(room_id, leader_id, session.connection_id)
2545        .await?;
2546
2547    session
2548        .peer
2549        .forward_send(session.connection_id, leader_id, request)?;
2550
2551    if let Some(project_id) = project_id {
2552        let room = session
2553            .db()
2554            .await
2555            .unfollow(room_id, project_id, leader_id, follower_id)
2556            .await?;
2557        room_updated(&room, &session.peer);
2558    }
2559
2560    Ok(())
2561}
2562
2563/// Notify everyone following you of your current location.
2564async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2565    let room_id = RoomId::from_proto(request.room_id);
2566    let database = session.db.lock().await;
2567
2568    let connection_ids = if let Some(project_id) = request.project_id {
2569        let project_id = ProjectId::from_proto(project_id);
2570        database
2571            .project_connection_ids(project_id, session.connection_id, true)
2572            .await?
2573    } else {
2574        database
2575            .room_connection_ids(room_id, session.connection_id)
2576            .await?
2577    };
2578
2579    // For now, don't send view update messages back to that view's current leader.
2580    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2581        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2582        _ => None,
2583    });
2584
2585    for connection_id in connection_ids.iter().cloned() {
2586        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2587            session
2588                .peer
2589                .forward_send(session.connection_id, connection_id, request.clone())?;
2590        }
2591    }
2592    Ok(())
2593}
2594
2595/// Get public data about users.
2596async fn get_users(
2597    request: proto::GetUsers,
2598    response: Response<proto::GetUsers>,
2599    session: MessageContext,
2600) -> Result<()> {
2601    let user_ids = request
2602        .user_ids
2603        .into_iter()
2604        .map(UserId::from_proto)
2605        .collect();
2606    let users = session
2607        .db()
2608        .await
2609        .get_users_by_ids(user_ids)
2610        .await?
2611        .into_iter()
2612        .map(|user| proto::User {
2613            id: user.id.to_proto(),
2614            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2615            github_login: user.github_login,
2616            name: user.name,
2617        })
2618        .collect();
2619    response.send(proto::UsersResponse { users })?;
2620    Ok(())
2621}
2622
2623/// Search for users (to invite) buy Github login
2624async fn fuzzy_search_users(
2625    request: proto::FuzzySearchUsers,
2626    response: Response<proto::FuzzySearchUsers>,
2627    session: MessageContext,
2628) -> Result<()> {
2629    let query = request.query;
2630    let users = match query.len() {
2631        0 => vec![],
2632        1 | 2 => session
2633            .db()
2634            .await
2635            .get_user_by_github_login(&query)
2636            .await?
2637            .into_iter()
2638            .collect(),
2639        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2640    };
2641    let users = users
2642        .into_iter()
2643        .filter(|user| user.id != session.user_id())
2644        .map(|user| proto::User {
2645            id: user.id.to_proto(),
2646            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2647            github_login: user.github_login,
2648            name: user.name,
2649        })
2650        .collect();
2651    response.send(proto::UsersResponse { users })?;
2652    Ok(())
2653}
2654
2655/// Send a contact request to another user.
2656async fn request_contact(
2657    request: proto::RequestContact,
2658    response: Response<proto::RequestContact>,
2659    session: MessageContext,
2660) -> Result<()> {
2661    let requester_id = session.user_id();
2662    let responder_id = UserId::from_proto(request.responder_id);
2663    if requester_id == responder_id {
2664        return Err(anyhow!("cannot add yourself as a contact"))?;
2665    }
2666
2667    let notifications = session
2668        .db()
2669        .await
2670        .send_contact_request(requester_id, responder_id)
2671        .await?;
2672
2673    // Update outgoing contact requests of requester
2674    let mut update = proto::UpdateContacts::default();
2675    update.outgoing_requests.push(responder_id.to_proto());
2676    for connection_id in session
2677        .connection_pool()
2678        .await
2679        .user_connection_ids(requester_id)
2680    {
2681        session.peer.send(connection_id, update.clone())?;
2682    }
2683
2684    // Update incoming contact requests of responder
2685    let mut update = proto::UpdateContacts::default();
2686    update
2687        .incoming_requests
2688        .push(proto::IncomingContactRequest {
2689            requester_id: requester_id.to_proto(),
2690        });
2691    let connection_pool = session.connection_pool().await;
2692    for connection_id in connection_pool.user_connection_ids(responder_id) {
2693        session.peer.send(connection_id, update.clone())?;
2694    }
2695
2696    send_notifications(&connection_pool, &session.peer, notifications);
2697
2698    response.send(proto::Ack {})?;
2699    Ok(())
2700}
2701
2702/// Accept or decline a contact request
2703async fn respond_to_contact_request(
2704    request: proto::RespondToContactRequest,
2705    response: Response<proto::RespondToContactRequest>,
2706    session: MessageContext,
2707) -> Result<()> {
2708    let responder_id = session.user_id();
2709    let requester_id = UserId::from_proto(request.requester_id);
2710    let db = session.db().await;
2711    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2712        db.dismiss_contact_notification(responder_id, requester_id)
2713            .await?;
2714    } else {
2715        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2716
2717        let notifications = db
2718            .respond_to_contact_request(responder_id, requester_id, accept)
2719            .await?;
2720        let requester_busy = db.is_user_busy(requester_id).await?;
2721        let responder_busy = db.is_user_busy(responder_id).await?;
2722
2723        let pool = session.connection_pool().await;
2724        // Update responder with new contact
2725        let mut update = proto::UpdateContacts::default();
2726        if accept {
2727            update
2728                .contacts
2729                .push(contact_for_user(requester_id, requester_busy, &pool));
2730        }
2731        update
2732            .remove_incoming_requests
2733            .push(requester_id.to_proto());
2734        for connection_id in pool.user_connection_ids(responder_id) {
2735            session.peer.send(connection_id, update.clone())?;
2736        }
2737
2738        // Update requester with new contact
2739        let mut update = proto::UpdateContacts::default();
2740        if accept {
2741            update
2742                .contacts
2743                .push(contact_for_user(responder_id, responder_busy, &pool));
2744        }
2745        update
2746            .remove_outgoing_requests
2747            .push(responder_id.to_proto());
2748
2749        for connection_id in pool.user_connection_ids(requester_id) {
2750            session.peer.send(connection_id, update.clone())?;
2751        }
2752
2753        send_notifications(&pool, &session.peer, notifications);
2754    }
2755
2756    response.send(proto::Ack {})?;
2757    Ok(())
2758}
2759
2760/// Remove a contact.
2761async fn remove_contact(
2762    request: proto::RemoveContact,
2763    response: Response<proto::RemoveContact>,
2764    session: MessageContext,
2765) -> Result<()> {
2766    let requester_id = session.user_id();
2767    let responder_id = UserId::from_proto(request.user_id);
2768    let db = session.db().await;
2769    let (contact_accepted, deleted_notification_id) =
2770        db.remove_contact(requester_id, responder_id).await?;
2771
2772    let pool = session.connection_pool().await;
2773    // Update outgoing contact requests of requester
2774    let mut update = proto::UpdateContacts::default();
2775    if contact_accepted {
2776        update.remove_contacts.push(responder_id.to_proto());
2777    } else {
2778        update
2779            .remove_outgoing_requests
2780            .push(responder_id.to_proto());
2781    }
2782    for connection_id in pool.user_connection_ids(requester_id) {
2783        session.peer.send(connection_id, update.clone())?;
2784    }
2785
2786    // Update incoming contact requests of responder
2787    let mut update = proto::UpdateContacts::default();
2788    if contact_accepted {
2789        update.remove_contacts.push(requester_id.to_proto());
2790    } else {
2791        update
2792            .remove_incoming_requests
2793            .push(requester_id.to_proto());
2794    }
2795    for connection_id in pool.user_connection_ids(responder_id) {
2796        session.peer.send(connection_id, update.clone())?;
2797        if let Some(notification_id) = deleted_notification_id {
2798            session.peer.send(
2799                connection_id,
2800                proto::DeleteNotification {
2801                    notification_id: notification_id.to_proto(),
2802                },
2803            )?;
2804        }
2805    }
2806
2807    response.send(proto::Ack {})?;
2808    Ok(())
2809}
2810
2811fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2812    version.0.minor() < 139
2813}
2814
2815async fn subscribe_to_channels(
2816    _: proto::SubscribeToChannels,
2817    session: MessageContext,
2818) -> Result<()> {
2819    subscribe_user_to_channels(session.user_id(), &session).await?;
2820    Ok(())
2821}
2822
2823async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2824    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2825    let mut pool = session.connection_pool().await;
2826    for membership in &channels_for_user.channel_memberships {
2827        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2828    }
2829    session.peer.send(
2830        session.connection_id,
2831        build_update_user_channels(&channels_for_user),
2832    )?;
2833    session.peer.send(
2834        session.connection_id,
2835        build_channels_update(channels_for_user),
2836    )?;
2837    Ok(())
2838}
2839
2840/// Creates a new channel.
2841async fn create_channel(
2842    request: proto::CreateChannel,
2843    response: Response<proto::CreateChannel>,
2844    session: MessageContext,
2845) -> Result<()> {
2846    let db = session.db().await;
2847
2848    let parent_id = request.parent_id.map(ChannelId::from_proto);
2849    let (channel, membership) = db
2850        .create_channel(&request.name, parent_id, session.user_id())
2851        .await?;
2852
2853    let root_id = channel.root_id();
2854    let channel = Channel::from_model(channel);
2855
2856    response.send(proto::CreateChannelResponse {
2857        channel: Some(channel.to_proto()),
2858        parent_id: request.parent_id,
2859    })?;
2860
2861    let mut connection_pool = session.connection_pool().await;
2862    if let Some(membership) = membership {
2863        connection_pool.subscribe_to_channel(
2864            membership.user_id,
2865            membership.channel_id,
2866            membership.role,
2867        );
2868        let update = proto::UpdateUserChannels {
2869            channel_memberships: vec![proto::ChannelMembership {
2870                channel_id: membership.channel_id.to_proto(),
2871                role: membership.role.into(),
2872            }],
2873            ..Default::default()
2874        };
2875        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2876            session.peer.send(connection_id, update.clone())?;
2877        }
2878    }
2879
2880    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2881        if !role.can_see_channel(channel.visibility) {
2882            continue;
2883        }
2884
2885        let update = proto::UpdateChannels {
2886            channels: vec![channel.to_proto()],
2887            ..Default::default()
2888        };
2889        session.peer.send(connection_id, update.clone())?;
2890    }
2891
2892    Ok(())
2893}
2894
2895/// Delete a channel
2896async fn delete_channel(
2897    request: proto::DeleteChannel,
2898    response: Response<proto::DeleteChannel>,
2899    session: MessageContext,
2900) -> Result<()> {
2901    let db = session.db().await;
2902
2903    let channel_id = request.channel_id;
2904    let (root_channel, removed_channels) = db
2905        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2906        .await?;
2907    response.send(proto::Ack {})?;
2908
2909    // Notify members of removed channels
2910    let mut update = proto::UpdateChannels::default();
2911    update
2912        .delete_channels
2913        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2914
2915    let connection_pool = session.connection_pool().await;
2916    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2917        session.peer.send(connection_id, update.clone())?;
2918    }
2919
2920    Ok(())
2921}
2922
2923/// Invite someone to join a channel.
2924async fn invite_channel_member(
2925    request: proto::InviteChannelMember,
2926    response: Response<proto::InviteChannelMember>,
2927    session: MessageContext,
2928) -> Result<()> {
2929    let db = session.db().await;
2930    let channel_id = ChannelId::from_proto(request.channel_id);
2931    let invitee_id = UserId::from_proto(request.user_id);
2932    let InviteMemberResult {
2933        channel,
2934        notifications,
2935    } = db
2936        .invite_channel_member(
2937            channel_id,
2938            invitee_id,
2939            session.user_id(),
2940            request.role().into(),
2941        )
2942        .await?;
2943
2944    let update = proto::UpdateChannels {
2945        channel_invitations: vec![channel.to_proto()],
2946        ..Default::default()
2947    };
2948
2949    let connection_pool = session.connection_pool().await;
2950    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2951        session.peer.send(connection_id, update.clone())?;
2952    }
2953
2954    send_notifications(&connection_pool, &session.peer, notifications);
2955
2956    response.send(proto::Ack {})?;
2957    Ok(())
2958}
2959
2960/// remove someone from a channel
2961async fn remove_channel_member(
2962    request: proto::RemoveChannelMember,
2963    response: Response<proto::RemoveChannelMember>,
2964    session: MessageContext,
2965) -> Result<()> {
2966    let db = session.db().await;
2967    let channel_id = ChannelId::from_proto(request.channel_id);
2968    let member_id = UserId::from_proto(request.user_id);
2969
2970    let RemoveChannelMemberResult {
2971        membership_update,
2972        notification_id,
2973    } = db
2974        .remove_channel_member(channel_id, member_id, session.user_id())
2975        .await?;
2976
2977    let mut connection_pool = session.connection_pool().await;
2978    notify_membership_updated(
2979        &mut connection_pool,
2980        membership_update,
2981        member_id,
2982        &session.peer,
2983    );
2984    for connection_id in connection_pool.user_connection_ids(member_id) {
2985        if let Some(notification_id) = notification_id {
2986            session
2987                .peer
2988                .send(
2989                    connection_id,
2990                    proto::DeleteNotification {
2991                        notification_id: notification_id.to_proto(),
2992                    },
2993                )
2994                .trace_err();
2995        }
2996    }
2997
2998    response.send(proto::Ack {})?;
2999    Ok(())
3000}
3001
3002/// Toggle the channel between public and private.
3003/// Care is taken to maintain the invariant that public channels only descend from public channels,
3004/// (though members-only channels can appear at any point in the hierarchy).
3005async fn set_channel_visibility(
3006    request: proto::SetChannelVisibility,
3007    response: Response<proto::SetChannelVisibility>,
3008    session: MessageContext,
3009) -> Result<()> {
3010    let db = session.db().await;
3011    let channel_id = ChannelId::from_proto(request.channel_id);
3012    let visibility = request.visibility().into();
3013
3014    let channel_model = db
3015        .set_channel_visibility(channel_id, visibility, session.user_id())
3016        .await?;
3017    let root_id = channel_model.root_id();
3018    let channel = Channel::from_model(channel_model);
3019
3020    let mut connection_pool = session.connection_pool().await;
3021    for (user_id, role) in connection_pool
3022        .channel_user_ids(root_id)
3023        .collect::<Vec<_>>()
3024        .into_iter()
3025    {
3026        let update = if role.can_see_channel(channel.visibility) {
3027            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3028            proto::UpdateChannels {
3029                channels: vec![channel.to_proto()],
3030                ..Default::default()
3031            }
3032        } else {
3033            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3034            proto::UpdateChannels {
3035                delete_channels: vec![channel.id.to_proto()],
3036                ..Default::default()
3037            }
3038        };
3039
3040        for connection_id in connection_pool.user_connection_ids(user_id) {
3041            session.peer.send(connection_id, update.clone())?;
3042        }
3043    }
3044
3045    response.send(proto::Ack {})?;
3046    Ok(())
3047}
3048
3049/// Alter the role for a user in the channel.
3050async fn set_channel_member_role(
3051    request: proto::SetChannelMemberRole,
3052    response: Response<proto::SetChannelMemberRole>,
3053    session: MessageContext,
3054) -> Result<()> {
3055    let db = session.db().await;
3056    let channel_id = ChannelId::from_proto(request.channel_id);
3057    let member_id = UserId::from_proto(request.user_id);
3058    let result = db
3059        .set_channel_member_role(
3060            channel_id,
3061            session.user_id(),
3062            member_id,
3063            request.role().into(),
3064        )
3065        .await?;
3066
3067    match result {
3068        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3069            let mut connection_pool = session.connection_pool().await;
3070            notify_membership_updated(
3071                &mut connection_pool,
3072                membership_update,
3073                member_id,
3074                &session.peer,
3075            )
3076        }
3077        db::SetMemberRoleResult::InviteUpdated(channel) => {
3078            let update = proto::UpdateChannels {
3079                channel_invitations: vec![channel.to_proto()],
3080                ..Default::default()
3081            };
3082
3083            for connection_id in session
3084                .connection_pool()
3085                .await
3086                .user_connection_ids(member_id)
3087            {
3088                session.peer.send(connection_id, update.clone())?;
3089            }
3090        }
3091    }
3092
3093    response.send(proto::Ack {})?;
3094    Ok(())
3095}
3096
3097/// Change the name of a channel
3098async fn rename_channel(
3099    request: proto::RenameChannel,
3100    response: Response<proto::RenameChannel>,
3101    session: MessageContext,
3102) -> Result<()> {
3103    let db = session.db().await;
3104    let channel_id = ChannelId::from_proto(request.channel_id);
3105    let channel_model = db
3106        .rename_channel(channel_id, session.user_id(), &request.name)
3107        .await?;
3108    let root_id = channel_model.root_id();
3109    let channel = Channel::from_model(channel_model);
3110
3111    response.send(proto::RenameChannelResponse {
3112        channel: Some(channel.to_proto()),
3113    })?;
3114
3115    let connection_pool = session.connection_pool().await;
3116    let update = proto::UpdateChannels {
3117        channels: vec![channel.to_proto()],
3118        ..Default::default()
3119    };
3120    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3121        if role.can_see_channel(channel.visibility) {
3122            session.peer.send(connection_id, update.clone())?;
3123        }
3124    }
3125
3126    Ok(())
3127}
3128
3129/// Move a channel to a new parent.
3130async fn move_channel(
3131    request: proto::MoveChannel,
3132    response: Response<proto::MoveChannel>,
3133    session: MessageContext,
3134) -> Result<()> {
3135    let channel_id = ChannelId::from_proto(request.channel_id);
3136    let to = ChannelId::from_proto(request.to);
3137
3138    let (root_id, channels) = session
3139        .db()
3140        .await
3141        .move_channel(channel_id, to, session.user_id())
3142        .await?;
3143
3144    let connection_pool = session.connection_pool().await;
3145    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3146        let channels = channels
3147            .iter()
3148            .filter_map(|channel| {
3149                if role.can_see_channel(channel.visibility) {
3150                    Some(channel.to_proto())
3151                } else {
3152                    None
3153                }
3154            })
3155            .collect::<Vec<_>>();
3156        if channels.is_empty() {
3157            continue;
3158        }
3159
3160        let update = proto::UpdateChannels {
3161            channels,
3162            ..Default::default()
3163        };
3164
3165        session.peer.send(connection_id, update.clone())?;
3166    }
3167
3168    response.send(Ack {})?;
3169    Ok(())
3170}
3171
3172async fn reorder_channel(
3173    request: proto::ReorderChannel,
3174    response: Response<proto::ReorderChannel>,
3175    session: MessageContext,
3176) -> Result<()> {
3177    let channel_id = ChannelId::from_proto(request.channel_id);
3178    let direction = request.direction();
3179
3180    let updated_channels = session
3181        .db()
3182        .await
3183        .reorder_channel(channel_id, direction, session.user_id())
3184        .await?;
3185
3186    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3187        let connection_pool = session.connection_pool().await;
3188        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3189            let channels = updated_channels
3190                .iter()
3191                .filter_map(|channel| {
3192                    if role.can_see_channel(channel.visibility) {
3193                        Some(channel.to_proto())
3194                    } else {
3195                        None
3196                    }
3197                })
3198                .collect::<Vec<_>>();
3199
3200            if channels.is_empty() {
3201                continue;
3202            }
3203
3204            let update = proto::UpdateChannels {
3205                channels,
3206                ..Default::default()
3207            };
3208
3209            session.peer.send(connection_id, update.clone())?;
3210        }
3211    }
3212
3213    response.send(Ack {})?;
3214    Ok(())
3215}
3216
3217/// Get the list of channel members
3218async fn get_channel_members(
3219    request: proto::GetChannelMembers,
3220    response: Response<proto::GetChannelMembers>,
3221    session: MessageContext,
3222) -> Result<()> {
3223    let db = session.db().await;
3224    let channel_id = ChannelId::from_proto(request.channel_id);
3225    let limit = if request.limit == 0 {
3226        u16::MAX as u64
3227    } else {
3228        request.limit
3229    };
3230    let (members, users) = db
3231        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3232        .await?;
3233    response.send(proto::GetChannelMembersResponse { members, users })?;
3234    Ok(())
3235}
3236
3237/// Accept or decline a channel invitation.
3238async fn respond_to_channel_invite(
3239    request: proto::RespondToChannelInvite,
3240    response: Response<proto::RespondToChannelInvite>,
3241    session: MessageContext,
3242) -> Result<()> {
3243    let db = session.db().await;
3244    let channel_id = ChannelId::from_proto(request.channel_id);
3245    let RespondToChannelInvite {
3246        membership_update,
3247        notifications,
3248    } = db
3249        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3250        .await?;
3251
3252    let mut connection_pool = session.connection_pool().await;
3253    if let Some(membership_update) = membership_update {
3254        notify_membership_updated(
3255            &mut connection_pool,
3256            membership_update,
3257            session.user_id(),
3258            &session.peer,
3259        );
3260    } else {
3261        let update = proto::UpdateChannels {
3262            remove_channel_invitations: vec![channel_id.to_proto()],
3263            ..Default::default()
3264        };
3265
3266        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3267            session.peer.send(connection_id, update.clone())?;
3268        }
3269    };
3270
3271    send_notifications(&connection_pool, &session.peer, notifications);
3272
3273    response.send(proto::Ack {})?;
3274
3275    Ok(())
3276}
3277
3278/// Join the channels' room
3279async fn join_channel(
3280    request: proto::JoinChannel,
3281    response: Response<proto::JoinChannel>,
3282    session: MessageContext,
3283) -> Result<()> {
3284    let channel_id = ChannelId::from_proto(request.channel_id);
3285    join_channel_internal(channel_id, Box::new(response), session).await
3286}
3287
/// A response handle whose reply payload is a `proto::JoinRoomResponse`.
///
/// This lets `join_channel_internal` answer either a `proto::JoinChannel` or a
/// `proto::JoinRoom` request with the same code path.
trait JoinChannelInternalResponse {
    /// Consumes the handle and sends `result` back to the requester.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Replies to a `proto::JoinChannel` request.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Explicit path call, so this forwards to `Response`'s own `send`.
        // NOTE(review): relies on `Response` having an inherent `send`; otherwise
        // this would resolve back to this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Replies to a `proto::JoinRoom` request.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same forwarding as the `proto::JoinChannel` impl.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3301
/// Shared implementation for joining a channel's room.
///
/// Steps, in order:
/// 1. Evicts a stale room connection belonging to this user, if the database
///    reports one.
/// 2. Joins the room in the database and, when a LiveKit client is configured,
///    creates LiveKit credentials for the user.
/// 3. Sends the `JoinRoomResponse` through `response`.
/// 4. Notifies the user's connections of any membership change and broadcasts
///    the updated room to its participants.
/// 5. Broadcasts the channel's participant list and refreshes the user's
///    contacts.
///
/// # Errors
///
/// Returns an error if any of the following fails: a database call, sending
/// the response, the contact refresh, or the database returning a channel for
/// the joined room. The last case is detected only after the response has
/// already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    // The inner block scopes the database and connection-pool guards. Both are
    // dropped before the connection pool is acquired again for `channel_updated`.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard first, because
            // `leave_room_for_session` acquires it itself.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // LiveKit credentials are only produced when a LiveKit client is
        // configured. Guests get a guest token and cannot publish; every other
        // role gets a room token with publish rights. If token creation fails,
        // `trace_err` logs the error and the response carries no LiveKit info.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the requester before fanning out updates to other peers.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Errors here if the database returned no channel for the room, even
    // though the response has already been sent above.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3395
3396/// Start editing the channel notes
3397async fn join_channel_buffer(
3398    request: proto::JoinChannelBuffer,
3399    response: Response<proto::JoinChannelBuffer>,
3400    session: MessageContext,
3401) -> Result<()> {
3402    let db = session.db().await;
3403    let channel_id = ChannelId::from_proto(request.channel_id);
3404
3405    let open_response = db
3406        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3407        .await?;
3408
3409    let collaborators = open_response.collaborators.clone();
3410    response.send(open_response)?;
3411
3412    let update = UpdateChannelBufferCollaborators {
3413        channel_id: channel_id.to_proto(),
3414        collaborators: collaborators.clone(),
3415    };
3416    channel_buffer_updated(
3417        session.connection_id,
3418        collaborators
3419            .iter()
3420            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3421        &update,
3422        &session.peer,
3423    );
3424
3425    Ok(())
3426}
3427
3428/// Edit the channel notes
3429async fn update_channel_buffer(
3430    request: proto::UpdateChannelBuffer,
3431    session: MessageContext,
3432) -> Result<()> {
3433    let db = session.db().await;
3434    let channel_id = ChannelId::from_proto(request.channel_id);
3435
3436    let (collaborators, epoch, version) = db
3437        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3438        .await?;
3439
3440    channel_buffer_updated(
3441        session.connection_id,
3442        collaborators.clone(),
3443        &proto::UpdateChannelBuffer {
3444            channel_id: channel_id.to_proto(),
3445            operations: request.operations,
3446        },
3447        &session.peer,
3448    );
3449
3450    let pool = &*session.connection_pool().await;
3451
3452    let non_collaborators =
3453        pool.channel_connection_ids(channel_id)
3454            .filter_map(|(connection_id, _)| {
3455                if collaborators.contains(&connection_id) {
3456                    None
3457                } else {
3458                    Some(connection_id)
3459                }
3460            });
3461
3462    broadcast(None, non_collaborators, |peer_id| {
3463        session.peer.send(
3464            peer_id,
3465            proto::UpdateChannels {
3466                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3467                    channel_id: channel_id.to_proto(),
3468                    epoch: epoch as u64,
3469                    version: version.clone(),
3470                }],
3471                ..Default::default()
3472            },
3473        )
3474    });
3475
3476    Ok(())
3477}
3478
3479/// Rejoin the channel notes after a connection blip
3480async fn rejoin_channel_buffers(
3481    request: proto::RejoinChannelBuffers,
3482    response: Response<proto::RejoinChannelBuffers>,
3483    session: MessageContext,
3484) -> Result<()> {
3485    let db = session.db().await;
3486    let buffers = db
3487        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3488        .await?;
3489
3490    for rejoined_buffer in &buffers {
3491        let collaborators_to_notify = rejoined_buffer
3492            .buffer
3493            .collaborators
3494            .iter()
3495            .filter_map(|c| Some(c.peer_id?.into()));
3496        channel_buffer_updated(
3497            session.connection_id,
3498            collaborators_to_notify,
3499            &proto::UpdateChannelBufferCollaborators {
3500                channel_id: rejoined_buffer.buffer.channel_id,
3501                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3502            },
3503            &session.peer,
3504        );
3505    }
3506
3507    response.send(proto::RejoinChannelBuffersResponse {
3508        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3509    })?;
3510
3511    Ok(())
3512}
3513
3514/// Stop editing the channel notes
3515async fn leave_channel_buffer(
3516    request: proto::LeaveChannelBuffer,
3517    response: Response<proto::LeaveChannelBuffer>,
3518    session: MessageContext,
3519) -> Result<()> {
3520    let db = session.db().await;
3521    let channel_id = ChannelId::from_proto(request.channel_id);
3522
3523    let left_buffer = db
3524        .leave_channel_buffer(channel_id, session.connection_id)
3525        .await?;
3526
3527    response.send(Ack {})?;
3528
3529    channel_buffer_updated(
3530        session.connection_id,
3531        left_buffer.connections,
3532        &proto::UpdateChannelBufferCollaborators {
3533            channel_id: channel_id.to_proto(),
3534            collaborators: left_buffer.collaborators,
3535        },
3536        &session.peer,
3537    );
3538
3539    Ok(())
3540}
3541
3542fn channel_buffer_updated<T: EnvelopedMessage>(
3543    sender_id: ConnectionId,
3544    collaborators: impl IntoIterator<Item = ConnectionId>,
3545    message: &T,
3546    peer: &Peer,
3547) {
3548    broadcast(Some(sender_id), collaborators, |peer_id| {
3549        peer.send(peer_id, message.clone())
3550    });
3551}
3552
3553fn send_notifications(
3554    connection_pool: &ConnectionPool,
3555    peer: &Peer,
3556    notifications: db::NotificationBatch,
3557) {
3558    for (user_id, notification) in notifications {
3559        for connection_id in connection_pool.user_connection_ids(user_id) {
3560            if let Err(error) = peer.send(
3561                connection_id,
3562                proto::AddNotification {
3563                    notification: Some(notification.clone()),
3564                },
3565            ) {
3566                tracing::error!(
3567                    "failed to send notification to {:?} {}",
3568                    connection_id,
3569                    error
3570                );
3571            }
3572        }
3573    }
3574}
3575
3576/// Send a message to the channel
3577async fn send_channel_message(
3578    _request: proto::SendChannelMessage,
3579    _response: Response<proto::SendChannelMessage>,
3580    _session: MessageContext,
3581) -> Result<()> {
3582    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3583}
3584
3585/// Delete a channel message
3586async fn remove_channel_message(
3587    _request: proto::RemoveChannelMessage,
3588    _response: Response<proto::RemoveChannelMessage>,
3589    _session: MessageContext,
3590) -> Result<()> {
3591    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3592}
3593
3594async fn update_channel_message(
3595    _request: proto::UpdateChannelMessage,
3596    _response: Response<proto::UpdateChannelMessage>,
3597    _session: MessageContext,
3598) -> Result<()> {
3599    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3600}
3601
3602/// Mark a channel message as read
3603async fn acknowledge_channel_message(
3604    _request: proto::AckChannelMessage,
3605    _session: MessageContext,
3606) -> Result<()> {
3607    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3608}
3609
3610/// Mark a buffer version as synced
3611async fn acknowledge_buffer_version(
3612    request: proto::AckBufferOperation,
3613    session: MessageContext,
3614) -> Result<()> {
3615    let buffer_id = BufferId::from_proto(request.buffer_id);
3616    session
3617        .db()
3618        .await
3619        .observe_buffer_version(
3620            buffer_id,
3621            session.user_id(),
3622            request.epoch as i32,
3623            &request.version,
3624        )
3625        .await?;
3626    Ok(())
3627}
3628
3629/// Get a Supermaven API key for the user
3630async fn get_supermaven_api_key(
3631    _request: proto::GetSupermavenApiKey,
3632    response: Response<proto::GetSupermavenApiKey>,
3633    session: MessageContext,
3634) -> Result<()> {
3635    let user_id: String = session.user_id().to_string();
3636    if !session.is_staff() {
3637        return Err(anyhow!("supermaven not enabled for this account"))?;
3638    }
3639
3640    let email = session.email().context("user must have an email")?;
3641
3642    let supermaven_admin_api = session
3643        .supermaven_client
3644        .as_ref()
3645        .context("supermaven not configured")?;
3646
3647    let result = supermaven_admin_api
3648        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3649        .await?;
3650
3651    response.send(proto::GetSupermavenApiKeyResponse {
3652        api_key: result.api_key,
3653    })?;
3654
3655    Ok(())
3656}
3657
3658/// Start receiving chat updates for a channel
3659async fn join_channel_chat(
3660    _request: proto::JoinChannelChat,
3661    _response: Response<proto::JoinChannelChat>,
3662    _session: MessageContext,
3663) -> Result<()> {
3664    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3665}
3666
3667/// Stop receiving chat updates for a channel
3668async fn leave_channel_chat(
3669    _request: proto::LeaveChannelChat,
3670    _session: MessageContext,
3671) -> Result<()> {
3672    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3673}
3674
3675/// Retrieve the chat history for a channel
3676async fn get_channel_messages(
3677    _request: proto::GetChannelMessages,
3678    _response: Response<proto::GetChannelMessages>,
3679    _session: MessageContext,
3680) -> Result<()> {
3681    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3682}
3683
3684/// Retrieve specific chat messages
3685async fn get_channel_messages_by_id(
3686    _request: proto::GetChannelMessagesById,
3687    _response: Response<proto::GetChannelMessagesById>,
3688    _session: MessageContext,
3689) -> Result<()> {
3690    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3691}
3692
3693/// Retrieve the current users notifications
3694async fn get_notifications(
3695    request: proto::GetNotifications,
3696    response: Response<proto::GetNotifications>,
3697    session: MessageContext,
3698) -> Result<()> {
3699    let notifications = session
3700        .db()
3701        .await
3702        .get_notifications(
3703            session.user_id(),
3704            NOTIFICATION_COUNT_PER_PAGE,
3705            request.before_id.map(db::NotificationId::from_proto),
3706        )
3707        .await?;
3708    response.send(proto::GetNotificationsResponse {
3709        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3710        notifications,
3711    })?;
3712    Ok(())
3713}
3714
3715/// Mark notifications as read
3716async fn mark_notification_as_read(
3717    request: proto::MarkNotificationRead,
3718    response: Response<proto::MarkNotificationRead>,
3719    session: MessageContext,
3720) -> Result<()> {
3721    let database = &session.db().await;
3722    let notifications = database
3723        .mark_notification_as_read_by_id(
3724            session.user_id(),
3725            NotificationId::from_proto(request.notification_id),
3726        )
3727        .await?;
3728    send_notifications(
3729        &*session.connection_pool().await,
3730        &session.peer,
3731        notifications,
3732    );
3733    response.send(proto::Ack {})?;
3734    Ok(())
3735}
3736
3737fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3738    let message = match message {
3739        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3740        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3741        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3742        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3743        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3744            code: frame.code.into(),
3745            reason: frame.reason.as_str().to_owned().into(),
3746        })),
3747        // We should never receive a frame while reading the message, according
3748        // to the `tungstenite` maintainers:
3749        //
3750        // > It cannot occur when you read messages from the WebSocket, but it
3751        // > can be used when you want to send the raw frames (e.g. you want to
3752        // > send the frames to the WebSocket without composing the full message first).
3753        // >
3754        // > — https://github.com/snapview/tungstenite-rs/issues/268
3755        TungsteniteMessage::Frame(_) => {
3756            bail!("received an unexpected frame while reading the message")
3757        }
3758    };
3759
3760    Ok(message)
3761}
3762
3763fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3764    match message {
3765        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3766        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3767        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3768        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3769        AxumMessage::Close(frame) => {
3770            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3771                code: frame.code.into(),
3772                reason: frame.reason.as_ref().into(),
3773            }))
3774        }
3775    }
3776}
3777
3778fn notify_membership_updated(
3779    connection_pool: &mut ConnectionPool,
3780    result: MembershipUpdated,
3781    user_id: UserId,
3782    peer: &Peer,
3783) {
3784    for membership in &result.new_channels.channel_memberships {
3785        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3786    }
3787    for channel_id in &result.removed_channels {
3788        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3789    }
3790
3791    let user_channels_update = proto::UpdateUserChannels {
3792        channel_memberships: result
3793            .new_channels
3794            .channel_memberships
3795            .iter()
3796            .map(|cm| proto::ChannelMembership {
3797                channel_id: cm.channel_id.to_proto(),
3798                role: cm.role.into(),
3799            })
3800            .collect(),
3801        ..Default::default()
3802    };
3803
3804    let mut update = build_channels_update(result.new_channels);
3805    update.delete_channels = result
3806        .removed_channels
3807        .into_iter()
3808        .map(|id| id.to_proto())
3809        .collect();
3810    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3811
3812    for connection_id in connection_pool.user_connection_ids(user_id) {
3813        peer.send(connection_id, user_channels_update.clone())
3814            .trace_err();
3815        peer.send(connection_id, update.clone()).trace_err();
3816    }
3817}
3818
3819fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3820    proto::UpdateUserChannels {
3821        channel_memberships: channels
3822            .channel_memberships
3823            .iter()
3824            .map(|m| proto::ChannelMembership {
3825                channel_id: m.channel_id.to_proto(),
3826                role: m.role.into(),
3827            })
3828            .collect(),
3829        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3830    }
3831}
3832
3833fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3834    let mut update = proto::UpdateChannels::default();
3835
3836    for channel in channels.channels {
3837        update.channels.push(channel.to_proto());
3838    }
3839
3840    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3841
3842    for (channel_id, participants) in channels.channel_participants {
3843        update
3844            .channel_participants
3845            .push(proto::ChannelParticipants {
3846                channel_id: channel_id.to_proto(),
3847                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3848            });
3849    }
3850
3851    for channel in channels.invited_channels {
3852        update.channel_invitations.push(channel.to_proto());
3853    }
3854
3855    update
3856}
3857
3858fn build_initial_contacts_update(
3859    contacts: Vec<db::Contact>,
3860    pool: &ConnectionPool,
3861) -> proto::UpdateContacts {
3862    let mut update = proto::UpdateContacts::default();
3863
3864    for contact in contacts {
3865        match contact {
3866            db::Contact::Accepted { user_id, busy } => {
3867                update.contacts.push(contact_for_user(user_id, busy, pool));
3868            }
3869            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3870            db::Contact::Incoming { user_id } => {
3871                update
3872                    .incoming_requests
3873                    .push(proto::IncomingContactRequest {
3874                        requester_id: user_id.to_proto(),
3875                    })
3876            }
3877        }
3878    }
3879
3880    update
3881}
3882
3883fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3884    proto::Contact {
3885        user_id: user_id.to_proto(),
3886        online: pool.is_user_online(user_id),
3887        busy,
3888    }
3889}
3890
3891fn room_updated(room: &proto::Room, peer: &Peer) {
3892    broadcast(
3893        None,
3894        room.participants
3895            .iter()
3896            .filter_map(|participant| Some(participant.peer_id?.into())),
3897        |peer_id| {
3898            peer.send(
3899                peer_id,
3900                proto::RoomUpdated {
3901                    room: Some(room.clone()),
3902                },
3903            )
3904        },
3905    );
3906}
3907
3908fn channel_updated(
3909    channel: &db::channel::Model,
3910    room: &proto::Room,
3911    peer: &Peer,
3912    pool: &ConnectionPool,
3913) {
3914    let participants = room
3915        .participants
3916        .iter()
3917        .map(|p| p.user_id)
3918        .collect::<Vec<_>>();
3919
3920    broadcast(
3921        None,
3922        pool.channel_connection_ids(channel.root_id())
3923            .filter_map(|(channel_id, role)| {
3924                role.can_see_channel(channel.visibility)
3925                    .then_some(channel_id)
3926            }),
3927        |peer_id| {
3928            peer.send(
3929                peer_id,
3930                proto::UpdateChannels {
3931                    channel_participants: vec![proto::ChannelParticipants {
3932                        channel_id: channel.id.to_proto(),
3933                        participant_user_ids: participants.clone(),
3934                    }],
3935                    ..Default::default()
3936                },
3937            )
3938        },
3939    );
3940}
3941
3942async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3943    let db = session.db().await;
3944
3945    let contacts = db.get_contacts(user_id).await?;
3946    let busy = db.is_user_busy(user_id).await?;
3947
3948    let pool = session.connection_pool().await;
3949    let updated_contact = contact_for_user(user_id, busy, &pool);
3950    for contact in contacts {
3951        if let db::Contact::Accepted {
3952            user_id: contact_user_id,
3953            ..
3954        } = contact
3955        {
3956            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3957                session
3958                    .peer
3959                    .send(
3960                        contact_conn_id,
3961                        proto::UpdateContacts {
3962                            contacts: vec![updated_contact.clone()],
3963                            remove_contacts: Default::default(),
3964                            incoming_requests: Default::default(),
3965                            remove_incoming_requests: Default::default(),
3966                            outgoing_requests: Default::default(),
3967                            remove_outgoing_requests: Default::default(),
3968                        },
3969                    )
3970                    .trace_err();
3971            }
3972        }
3973    }
3974    Ok(())
3975}
3976
/// Removes `connection_id` from its room, if it is in one.
///
/// `connection_id` need not be `session.connection_id`. For example,
/// `join_channel_internal` passes a stale connection of the same user here.
/// When the database reports that a room was left, this:
/// - notifies the other collaborators of every project the connection left;
/// - broadcasts the updated room and, if it belongs to a channel, the channel's
///   participants;
/// - sends `CallCanceled` to the users whose pending calls were canceled;
/// - refreshes contacts for this user and those users;
/// - removes this user from the LiveKit room, deleting the LiveKit room when
///   the database reports the room as deleted.
///
/// Send and LiveKit failures are logged via `trace_err`. Database and
/// contact-refresh errors are returned.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: these are assigned only in the branch where a
    // room was actually left; the other branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` itself, so the `room` broadcast below carries an
        // empty `livekit_room`.
        // NOTE(review): confirm clients do not rely on that field in `RoomUpdated`.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection was not in a room; nothing to clean up.
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before the contact
    // refresh below, which acquires it again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged, not returned.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4050
4051async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4052    let left_channel_buffers = session
4053        .db()
4054        .await
4055        .leave_channel_buffers(session.connection_id)
4056        .await?;
4057
4058    for left_buffer in left_channel_buffers {
4059        channel_buffer_updated(
4060            session.connection_id,
4061            left_buffer.connections,
4062            &proto::UpdateChannelBufferCollaborators {
4063                channel_id: left_buffer.channel_id.to_proto(),
4064                collaborators: left_buffer.collaborators,
4065            },
4066            &session.peer,
4067        );
4068    }
4069
4070    Ok(())
4071}
4072
4073fn project_left(project: &db::LeftProject, session: &Session) {
4074    for connection_id in &project.connection_ids {
4075        if project.should_unshare {
4076            session
4077                .peer
4078                .send(
4079                    *connection_id,
4080                    proto::UnshareProject {
4081                        project_id: project.id.to_proto(),
4082                    },
4083                )
4084                .trace_err();
4085        } else {
4086            session
4087                .peer
4088                .send(
4089                    *connection_id,
4090                    proto::RemoveProjectCollaborator {
4091                        project_id: project.id.to_proto(),
4092                        peer_id: Some(session.connection_id.into()),
4093                    },
4094                )
4095                .trace_err();
4096        }
4097    }
4098}
4099
4100pub trait ResultExt {
4101    type Ok;
4102
4103    fn trace_err(self) -> Option<Self::Ok>;
4104}
4105
4106impl<T, E> ResultExt for Result<T, E>
4107where
4108    E: std::fmt::Debug,
4109{
4110    type Ok = T;
4111
4112    #[track_caller]
4113    fn trace_err(self) -> Option<T> {
4114        match self {
4115            Ok(value) => Some(value),
4116            Err(error) => {
4117                tracing::error!("{:?}", error);
4118                None
4119            }
4120        }
4121    }
4122}