1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::{
   5    AppState, Error, Result, auth,
   6    db::{
   7        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
   8        InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
   9        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
  10    },
  11    executor::Executor,
  12};
  13use anyhow::{Context as _, anyhow, bail};
  14use async_tungstenite::tungstenite::{
  15    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  16};
  17use axum::headers::UserAgent;
  18use axum::{
  19    Extension, Router, TypedHeader,
  20    body::Body,
  21    extract::{
  22        ConnectInfo, WebSocketUpgrade,
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30};
  31use collections::{HashMap, HashSet};
  32pub use connection_pool::{ConnectionPool, ZedVersion};
  33use core::fmt::{self, Debug, Formatter};
  34use futures::TryFutureExt as _;
  35use reqwest_client::ReqwestClient;
  36use rpc::proto::split_repository_update;
  37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  38use tracing::Span;
  39use util::paths::PathStyle;
  40
  41use futures::{
  42    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  43    stream::FuturesUnordered,
  44};
  45use prometheus::{IntGauge, register_int_gauge};
  46use rpc::{
  47    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  50        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52};
  53use semantic_version::SemanticVersion;
  54use serde::{Serialize, Serializer};
  55use std::{
  56    any::TypeId,
  57    future::Future,
  58    marker::PhantomData,
  59    mem,
  60    net::SocketAddr,
  61    ops::{Deref, DerefMut},
  62    rc::Rc,
  63    sync::{
  64        Arc, OnceLock,
  65        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  66    },
  67    time::{Duration, Instant},
  68};
  69use tokio::sync::{Semaphore, watch};
  70use tower::ServiceBuilder;
  71use tracing::{
  72    Instrument,
  73    field::{self},
  74    info_span, instrument,
  75};
  76
/// NOTE(review): not referenced in this excerpt; presumably how long a disconnected client
/// may take to reconnect before its state is released — confirm at the use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before the background task spawned by `Server::start` begins cleaning up stale
/// resources.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
/// clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// NOTE(review): not referenced in this excerpt; presumably the page size used when
/// fetching notifications — confirm at the use sites.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Maximum number of [`ConnectionGuard`]s that may be held at once.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Number of [`ConnectionGuard`]s currently alive (incremented on acquire, decremented on drop).
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Span field names for per-message timing; every value recorded under them is in
// milliseconds (microsecond precision, stored as f64).
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased message handler stored in `Server::handlers`, keyed by the message's `TypeId`.
/// It receives the boxed envelope, the sender's session and the message span, and returns
/// a boxed future that performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
  94
  95pub struct ConnectionGuard;
  96
  97impl ConnectionGuard {
  98    pub fn try_acquire() -> Result<Self, ()> {
  99        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
 100        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 101            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 102            tracing::error!(
 103                "too many concurrent connections: {}",
 104                current_connections + 1
 105            );
 106            return Err(());
 107        }
 108        Ok(ConnectionGuard)
 109    }
 110}
 111
 112impl Drop for ConnectionGuard {
 113    fn drop(&mut self) {
 114        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 115    }
 116}
 117
 118struct Response<R> {
 119    peer: Arc<Peer>,
 120    receipt: Receipt<R>,
 121    responded: Arc<AtomicBool>,
 122}
 123
 124impl<R: RequestMessage> Response<R> {
 125    fn send(self, payload: R::Response) -> Result<()> {
 126        self.responded.store(true, SeqCst);
 127        self.peer.respond(self.receipt, payload)?;
 128        Ok(())
 129    }
 130}
 131
 132#[derive(Clone, Debug)]
 133pub enum Principal {
 134    User(User),
 135    Impersonated { user: User, admin: User },
 136}
 137
 138impl Principal {
 139    fn update_span(&self, span: &tracing::Span) {
 140        match &self {
 141            Principal::User(user) => {
 142                span.record("user_id", user.id.0);
 143                span.record("login", &user.github_login);
 144            }
 145            Principal::Impersonated { user, admin } => {
 146                span.record("user_id", user.id.0);
 147                span.record("login", &user.github_login);
 148                span.record("impersonator", &admin.github_login);
 149            }
 150        }
 151    }
 152}
 153
 154#[derive(Clone)]
 155struct MessageContext {
 156    session: Session,
 157    span: tracing::Span,
 158}
 159
 160impl Deref for MessageContext {
 161    type Target = Session;
 162
 163    fn deref(&self) -> &Self::Target {
 164        &self.session
 165    }
 166}
 167
 168impl MessageContext {
 169    pub fn forward_request<T: RequestMessage>(
 170        &self,
 171        receiver_id: ConnectionId,
 172        request: T,
 173    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 174        let request_start_time = Instant::now();
 175        let span = self.span.clone();
 176        tracing::info!("start forwarding request");
 177        self.peer
 178            .forward_request(self.connection_id, receiver_id, request)
 179            .inspect(move |_| {
 180                span.record(
 181                    HOST_WAITING_MS,
 182                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 183                );
 184            })
 185            .inspect_err(|_| tracing::error!("error forwarding request"))
 186            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 187    }
 188}
 189
 190#[derive(Clone)]
 191struct Session {
 192    principal: Principal,
 193    connection_id: ConnectionId,
 194    db: Arc<tokio::sync::Mutex<DbHandle>>,
 195    peer: Arc<Peer>,
 196    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 197    app_state: Arc<AppState>,
 198    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 199    /// The GeoIP country code for the user.
 200    #[allow(unused)]
 201    geoip_country_code: Option<String>,
 202    #[allow(unused)]
 203    system_id: Option<String>,
 204    _executor: Executor,
 205}
 206
 207impl Session {
 208    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 209        #[cfg(test)]
 210        tokio::task::yield_now().await;
 211        let guard = self.db.lock().await;
 212        #[cfg(test)]
 213        tokio::task::yield_now().await;
 214        guard
 215    }
 216
 217    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 218        #[cfg(test)]
 219        tokio::task::yield_now().await;
 220        let guard = self.connection_pool.lock();
 221        ConnectionPoolGuard {
 222            guard,
 223            _not_send: PhantomData,
 224        }
 225    }
 226
 227    fn is_staff(&self) -> bool {
 228        match &self.principal {
 229            Principal::User(user) => user.admin,
 230            Principal::Impersonated { .. } => true,
 231        }
 232    }
 233
 234    fn user_id(&self) -> UserId {
 235        match &self.principal {
 236            Principal::User(user) => user.id,
 237            Principal::Impersonated { user, .. } => user.id,
 238        }
 239    }
 240
 241    pub fn email(&self) -> Option<String> {
 242        match &self.principal {
 243            Principal::User(user) => user.email_address.clone(),
 244            Principal::Impersonated { user, .. } => user.email_address.clone(),
 245        }
 246    }
 247}
 248
 249impl Debug for Session {
 250    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 251        let mut result = f.debug_struct("Session");
 252        match &self.principal {
 253            Principal::User(user) => {
 254                result.field("user", &user.github_login);
 255            }
 256            Principal::Impersonated { user, admin } => {
 257                result.field("user", &user.github_login);
 258                result.field("impersonator", &admin.github_login);
 259            }
 260        }
 261        result.field("connection_id", &self.connection_id).finish()
 262    }
 263}
 264
 265struct DbHandle(Arc<Database>);
 266
 267impl Deref for DbHandle {
 268    type Target = Database;
 269
 270    fn deref(&self) -> &Self::Target {
 271        self.0.as_ref()
 272    }
 273}
 274
/// The collaboration RPC server. It owns the peer, the connection pool and the table of
/// registered message handlers.
pub struct Server {
    /// This server's id. It sits behind a mutex so it can be replaced through `&self`
    /// (the test-only `reset` does this).
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type each one handles.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`, and back to `false` by the test-only `reset`.
    teardown: watch::Sender<bool>,
}

/// Connection-pool lock guard that is deliberately `!Send`.
///
/// `PhantomData<Rc<()>>` makes the guard `!Send` regardless of the inner `parking_lot`
/// guard's own auto traits. As a result, `Send` futures cannot hold it across an `.await`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 288
/// Serializable view of the server: the peer plus the locked connection pool.
///
/// The snapshot holds the connection-pool lock for as long as it lives.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through the guard's `Deref` target (presumably the `ConnectionPool`).
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 295
 296pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 297where
 298    S: Serializer,
 299    T: Deref<Target = U>,
 300    U: Serialize,
 301{
 302    Serialize::serialize(value.deref(), serializer)
 303}
 304
 305impl Server {
 306    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 307        let mut server = Self {
 308            id: parking_lot::Mutex::new(id),
 309            peer: Peer::new(id.0 as u32),
 310            app_state,
 311            connection_pool: Default::default(),
 312            handlers: Default::default(),
 313            teardown: watch::channel(false).0,
 314        };
 315
 316        server
 317            .add_request_handler(ping)
 318            .add_request_handler(create_room)
 319            .add_request_handler(join_room)
 320            .add_request_handler(rejoin_room)
 321            .add_request_handler(leave_room)
 322            .add_request_handler(set_room_participant_role)
 323            .add_request_handler(call)
 324            .add_request_handler(cancel_call)
 325            .add_message_handler(decline_call)
 326            .add_request_handler(update_participant_location)
 327            .add_request_handler(share_project)
 328            .add_message_handler(unshare_project)
 329            .add_request_handler(join_project)
 330            .add_message_handler(leave_project)
 331            .add_request_handler(update_project)
 332            .add_request_handler(update_worktree)
 333            .add_request_handler(update_repository)
 334            .add_request_handler(remove_repository)
 335            .add_message_handler(start_language_server)
 336            .add_message_handler(update_language_server)
 337            .add_message_handler(update_diagnostic_summary)
 338            .add_message_handler(update_worktree_settings)
 339            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 340            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 341            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 342            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 343            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 344            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 345            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 346            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 347            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 348            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 349            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 350            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
 351            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 352            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 353            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 354            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 355            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 356            .add_request_handler(
 357                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 358            )
 359            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 360            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 361            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 362            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 363            .add_request_handler(
 364                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 365            )
 366            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 367            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 368            .add_request_handler(
 369                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 370            )
 371            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 372            .add_request_handler(
 373                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 374            )
 375            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 376            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 377            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 378            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 379            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 380            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 381            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 382            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 383            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 384            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 385            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 386            .add_request_handler(
 387                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 388            )
 389            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 390            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 391            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 392            .add_request_handler(lsp_query)
 393            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
 394            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 395            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 396            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 397            .add_message_handler(create_buffer_for_peer)
 398            .add_request_handler(update_buffer)
 399            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 400            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 401            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 402            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 403            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 404            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 405            .add_message_handler(
 406                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 407            )
 408            .add_request_handler(get_users)
 409            .add_request_handler(fuzzy_search_users)
 410            .add_request_handler(request_contact)
 411            .add_request_handler(remove_contact)
 412            .add_request_handler(respond_to_contact_request)
 413            .add_message_handler(subscribe_to_channels)
 414            .add_request_handler(create_channel)
 415            .add_request_handler(delete_channel)
 416            .add_request_handler(invite_channel_member)
 417            .add_request_handler(remove_channel_member)
 418            .add_request_handler(set_channel_member_role)
 419            .add_request_handler(set_channel_visibility)
 420            .add_request_handler(rename_channel)
 421            .add_request_handler(join_channel_buffer)
 422            .add_request_handler(leave_channel_buffer)
 423            .add_message_handler(update_channel_buffer)
 424            .add_request_handler(rejoin_channel_buffers)
 425            .add_request_handler(get_channel_members)
 426            .add_request_handler(respond_to_channel_invite)
 427            .add_request_handler(join_channel)
 428            .add_request_handler(join_channel_chat)
 429            .add_message_handler(leave_channel_chat)
 430            .add_request_handler(send_channel_message)
 431            .add_request_handler(remove_channel_message)
 432            .add_request_handler(update_channel_message)
 433            .add_request_handler(get_channel_messages)
 434            .add_request_handler(get_channel_messages_by_id)
 435            .add_request_handler(get_notifications)
 436            .add_request_handler(mark_notification_as_read)
 437            .add_request_handler(move_channel)
 438            .add_request_handler(reorder_channel)
 439            .add_request_handler(follow)
 440            .add_message_handler(unfollow)
 441            .add_message_handler(update_followers)
 442            .add_message_handler(acknowledge_channel_message)
 443            .add_message_handler(acknowledge_buffer_version)
 444            .add_request_handler(get_supermaven_api_key)
 445            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 446            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 447            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 448            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 449            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 450            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 451            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 452            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
 453            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 454            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 455            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 456            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 457            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 458            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 459            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 460            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 461            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 462            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 463            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 464            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 465            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
 466            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
 467            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 468            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 469            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 470            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 471            .add_message_handler(update_context)
 472            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
 473            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>);
 474
 475        Arc::new(server)
 476    }
 477
    /// Spawns a detached background task that cleans up resources the database reports as
    /// stale for this server's environment and id, then returns `Ok(())` immediately.
    ///
    /// NOTE(review): "stale" presumably means left behind by earlier server instances (see
    /// the Kubernetes note on `CLEANUP_TIMEOUT`). Confirm against the `db` methods.
    ///
    /// After waiting `CLEANUP_TIMEOUT`, the task:
    /// 1. deletes stale channel-chat participants;
    /// 2. for each stale channel buffer, clears stale collaborators and sends the remaining
    ///    connections an `UpdateChannelBufferCollaborators`;
    /// 3. for each stale room, clears stale participants, broadcasts the room/channel
    ///    update, sends `CallCanceled` to callees whose calls were canceled, pushes a refreshed
    ///    contact entry to the accepted contacts of every affected user, and deletes the
    ///    LiveKit room once it has no participants;
    /// 4. deletes stale chat participants again, clears old worktree entries and deletes
    ///    stale servers.
    ///
    /// Every fallible step goes through `trace_err()`, which turns the result into an
    /// `Option`, so a failing step is skipped rather than aborting the task.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and tell the remaining
                    // connections who is still collaborating.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled in only if the room refresh below succeeds. Otherwise nothing
                        // is notified and the LiveKit room is left alone.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the (synchronous) pool lock is released before the
                        // `.await`s in the contact loop below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry (busy state plus
                        // connection info from the pool) to all of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                // Taken only after both awaits above have completed.
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room when the refresh left it empty.
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                // NOTE(review): repeats the call made at the start of the task, presumably to
                // catch rows created while the cleanup above ran. Confirm both are needed.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 648
 649    pub fn teardown(&self) {
 650        self.peer.teardown();
 651        self.connection_pool.lock().reset();
 652        let _ = self.teardown.send(true);
 653    }
 654
 655    #[cfg(test)]
 656    pub fn reset(&self, id: ServerId) {
 657        self.teardown();
 658        *self.id.lock() = id;
 659        self.peer.reset(id.0 as u32);
 660        let _ = self.teardown.send(false);
 661    }
 662
 663    #[cfg(test)]
 664    pub fn id(&self) -> ServerId {
 665        *self.id.lock()
 666    }
 667
 668    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 669    where
 670        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
 671        Fut: 'static + Send + Future<Output = Result<()>>,
 672        M: EnvelopedMessage,
 673    {
 674        let prev_handler = self.handlers.insert(
 675            TypeId::of::<M>(),
 676            Box::new(move |envelope, session, span| {
 677                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 678                let received_at = envelope.received_at;
 679                tracing::info!("message received");
 680                let start_time = Instant::now();
 681                let future = (handler)(
 682                    *envelope,
 683                    MessageContext {
 684                        session,
 685                        span: span.clone(),
 686                    },
 687                );
 688                async move {
 689                    let result = future.await;
 690                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 691                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 692                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 693                    span.record(TOTAL_DURATION_MS, total_duration_ms);
 694                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
 695                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
 696                    match result {
 697                        Err(error) => {
 698                            tracing::error!(?error, "error handling message")
 699                        }
 700                        Ok(()) => tracing::info!("finished handling message"),
 701                    }
 702                }
 703                .boxed()
 704            }),
 705        );
 706        if prev_handler.is_some() {
 707            panic!("registered a handler for the same message twice");
 708        }
 709        self
 710    }
 711
 712    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 713    where
 714        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 715        Fut: 'static + Send + Future<Output = Result<()>>,
 716        M: EnvelopedMessage,
 717    {
 718        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 719        self
 720    }
 721
 722    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 723    where
 724        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 725        Fut: Send + Future<Output = Result<()>>,
 726        M: RequestMessage,
 727    {
 728        let handler = Arc::new(handler);
 729        self.add_handler(move |envelope, session| {
 730            let receipt = envelope.receipt();
 731            let handler = handler.clone();
 732            async move {
 733                let peer = session.peer.clone();
 734                let responded = Arc::new(AtomicBool::default());
 735                let response = Response {
 736                    peer: peer.clone(),
 737                    responded: responded.clone(),
 738                    receipt,
 739                };
 740                match (handler)(envelope.payload, response, session).await {
 741                    Ok(()) => {
 742                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 743                            Ok(())
 744                        } else {
 745                            Err(anyhow!("handler did not send a response"))?
 746                        }
 747                    }
 748                    Err(error) => {
 749                        let proto_err = match &error {
 750                            Error::Internal(err) => err.to_proto(),
 751                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 752                        };
 753                        peer.respond_with_error(receipt, proto_err)?;
 754                        Err(error)
 755                    }
 756                }
 757            }
 758        })
 759    }
 760
 761    pub fn handle_connection(
 762        self: &Arc<Self>,
 763        connection: Connection,
 764        address: String,
 765        principal: Principal,
 766        zed_version: ZedVersion,
 767        release_channel: Option<String>,
 768        user_agent: Option<String>,
 769        geoip_country_code: Option<String>,
 770        system_id: Option<String>,
 771        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 772        executor: Executor,
 773        connection_guard: Option<ConnectionGuard>,
 774    ) -> impl Future<Output = ()> + use<> {
 775        let this = self.clone();
 776        let span = info_span!("handle connection", %address,
 777            connection_id=field::Empty,
 778            user_id=field::Empty,
 779            login=field::Empty,
 780            impersonator=field::Empty,
 781            user_agent=field::Empty,
 782            geoip_country_code=field::Empty,
 783            release_channel=field::Empty,
 784        );
 785        principal.update_span(&span);
 786        if let Some(user_agent) = user_agent {
 787            span.record("user_agent", user_agent);
 788        }
 789        if let Some(release_channel) = release_channel {
 790            span.record("release_channel", release_channel);
 791        }
 792
 793        if let Some(country_code) = geoip_country_code.as_ref() {
 794            span.record("geoip_country_code", country_code);
 795        }
 796
 797        let mut teardown = self.teardown.subscribe();
 798        async move {
 799            if *teardown.borrow() {
 800                tracing::error!("server is tearing down");
 801                return;
 802            }
 803
 804            let (connection_id, handle_io, mut incoming_rx) =
 805                this.peer.add_connection(connection, {
 806                    let executor = executor.clone();
 807                    move |duration| executor.sleep(duration)
 808                });
 809            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 810
 811            tracing::info!("connection opened");
 812
 813            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
 814            let http_client = match ReqwestClient::user_agent(&user_agent) {
 815                Ok(http_client) => Arc::new(http_client),
 816                Err(error) => {
 817                    tracing::error!(?error, "failed to create HTTP client");
 818                    return;
 819                }
 820            };
 821
 822            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
 823                |supermaven_admin_api_key| {
 824                    Arc::new(SupermavenAdminApi::new(
 825                        supermaven_admin_api_key.to_string(),
 826                        http_client.clone(),
 827                    ))
 828                },
 829            );
 830
 831            let session = Session {
 832                principal: principal.clone(),
 833                connection_id,
 834                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 835                peer: this.peer.clone(),
 836                connection_pool: this.connection_pool.clone(),
 837                app_state: this.app_state.clone(),
 838                geoip_country_code,
 839                system_id,
 840                _executor: executor.clone(),
 841                supermaven_client,
 842            };
 843
 844            if let Err(error) = this
 845                .send_initial_client_update(
 846                    connection_id,
 847                    zed_version,
 848                    send_connection_id,
 849                    &session,
 850                )
 851                .await
 852            {
 853                tracing::error!(?error, "failed to send initial client update");
 854                return;
 855            }
 856            drop(connection_guard);
 857
 858            let handle_io = handle_io.fuse();
 859            futures::pin_mut!(handle_io);
 860
 861            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 862            // This prevents deadlocks when e.g., client A performs a request to client B and
 863            // client B performs a request to client A. If both clients stop processing further
 864            // messages until their respective request completes, they won't have a chance to
 865            // respond to the other client's request and cause a deadlock.
 866            //
 867            // This arrangement ensures we will attempt to process earlier messages first, but fall
 868            // back to processing messages arrived later in the spirit of making progress.
 869            const MAX_CONCURRENT_HANDLERS: usize = 256;
 870            let mut foreground_message_handlers = FuturesUnordered::new();
 871            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
 872            let get_concurrent_handlers = {
 873                let concurrent_handlers = concurrent_handlers.clone();
 874                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
 875            };
 876            loop {
 877                let next_message = async {
 878                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 879                    let message = incoming_rx.next().await;
 880                    // Cache the concurrent_handlers here, so that we know what the
 881                    // queue looks like as each handler starts
 882                    (permit, message, get_concurrent_handlers())
 883                }
 884                .fuse();
 885                futures::pin_mut!(next_message);
 886                futures::select_biased! {
 887                    _ = teardown.changed().fuse() => return,
 888                    result = handle_io => {
 889                        if let Err(error) = result {
 890                            tracing::error!(?error, "error handling I/O");
 891                        }
 892                        break;
 893                    }
 894                    _ = foreground_message_handlers.next() => {}
 895                    next_message = next_message => {
 896                        let (permit, message, concurrent_handlers) = next_message;
 897                        if let Some(message) = message {
 898                            let type_name = message.payload_type_name();
 899                            // note: we copy all the fields from the parent span so we can query them in the logs.
 900                            // (https://github.com/tokio-rs/tracing/issues/2670).
 901                            let span = tracing::info_span!("receive message",
 902                                %connection_id,
 903                                %address,
 904                                type_name,
 905                                concurrent_handlers,
 906                                user_id=field::Empty,
 907                                login=field::Empty,
 908                                impersonator=field::Empty,
 909                                lsp_query_request=field::Empty,
 910                                release_channel=field::Empty,
 911                                { TOTAL_DURATION_MS }=field::Empty,
 912                                { PROCESSING_DURATION_MS }=field::Empty,
 913                                { QUEUE_DURATION_MS }=field::Empty,
 914                                { HOST_WAITING_MS }=field::Empty
 915                            );
 916                            principal.update_span(&span);
 917                            let span_enter = span.enter();
 918                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 919                                let is_background = message.is_background();
 920                                let handle_message = (handler)(message, session.clone(), span.clone());
 921                                drop(span_enter);
 922
 923                                let handle_message = async move {
 924                                    handle_message.await;
 925                                    drop(permit);
 926                                }.instrument(span);
 927                                if is_background {
 928                                    executor.spawn_detached(handle_message);
 929                                } else {
 930                                    foreground_message_handlers.push(handle_message);
 931                                }
 932                            } else {
 933                                tracing::error!("no message handler");
 934                            }
 935                        } else {
 936                            tracing::info!("connection closed");
 937                            break;
 938                        }
 939                    }
 940                }
 941            }
 942
 943            drop(foreground_message_handlers);
 944            let concurrent_handlers = get_concurrent_handlers();
 945            tracing::info!(concurrent_handlers, "signing out");
 946            if let Err(error) = connection_lost(session, teardown, executor).await {
 947                tracing::error!(?error, "error signing out");
 948            }
 949        }
 950        .instrument(span)
 951    }
 952
 953    async fn send_initial_client_update(
 954        &self,
 955        connection_id: ConnectionId,
 956        zed_version: ZedVersion,
 957        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 958        session: &Session,
 959    ) -> Result<()> {
 960        self.peer.send(
 961            connection_id,
 962            proto::Hello {
 963                peer_id: Some(connection_id.into()),
 964            },
 965        )?;
 966        tracing::info!("sent hello message");
 967        if let Some(send_connection_id) = send_connection_id.take() {
 968            let _ = send_connection_id.send(connection_id);
 969        }
 970
 971        match &session.principal {
 972            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 973                if !user.connected_once {
 974                    self.peer.send(connection_id, proto::ShowContacts {})?;
 975                    self.app_state
 976                        .db
 977                        .set_user_connected_once(user.id, true)
 978                        .await?;
 979                }
 980
 981                let contacts = self.app_state.db.get_contacts(user.id).await?;
 982
 983                {
 984                    let mut pool = self.connection_pool.lock();
 985                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 986                    self.peer.send(
 987                        connection_id,
 988                        build_initial_contacts_update(contacts, &pool),
 989                    )?;
 990                }
 991
 992                if should_auto_subscribe_to_channels(zed_version) {
 993                    subscribe_user_to_channels(user.id, session).await?;
 994                }
 995
 996                if let Some(incoming_call) =
 997                    self.app_state.db.incoming_call_for_user(user.id).await?
 998                {
 999                    self.peer.send(connection_id, incoming_call)?;
1000                }
1001
1002                update_user_contacts(user.id, session).await?;
1003            }
1004        }
1005
1006        Ok(())
1007    }
1008
1009    pub async fn invite_code_redeemed(
1010        self: &Arc<Self>,
1011        inviter_id: UserId,
1012        invitee_id: UserId,
1013    ) -> Result<()> {
1014        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await?
1015            && let Some(code) = &user.invite_code
1016        {
1017            let pool = self.connection_pool.lock();
1018            let invitee_contact = contact_for_user(invitee_id, false, &pool);
1019            for connection_id in pool.user_connection_ids(inviter_id) {
1020                self.peer.send(
1021                    connection_id,
1022                    proto::UpdateContacts {
1023                        contacts: vec![invitee_contact.clone()],
1024                        ..Default::default()
1025                    },
1026                )?;
1027                self.peer.send(
1028                    connection_id,
1029                    proto::UpdateInviteInfo {
1030                        url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1031                        count: user.invite_count as u32,
1032                    },
1033                )?;
1034            }
1035        }
1036        Ok(())
1037    }
1038
1039    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1040        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await?
1041            && let Some(invite_code) = &user.invite_code
1042        {
1043            let pool = self.connection_pool.lock();
1044            for connection_id in pool.user_connection_ids(user_id) {
1045                self.peer.send(
1046                    connection_id,
1047                    proto::UpdateInviteInfo {
1048                        url: format!(
1049                            "{}{}",
1050                            self.app_state.config.invite_link_prefix, invite_code
1051                        ),
1052                        count: user.invite_count as u32,
1053                    },
1054                )?;
1055            }
1056        }
1057        Ok(())
1058    }
1059
1060    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1061        ServerSnapshot {
1062            connection_pool: ConnectionPoolGuard {
1063                guard: self.connection_pool.lock(),
1064                _not_send: PhantomData,
1065            },
1066            peer: &self.peer,
1067        }
1068    }
1069}
1070
1071impl Deref for ConnectionPoolGuard<'_> {
1072    type Target = ConnectionPool;
1073
1074    fn deref(&self) -> &Self::Target {
1075        &self.guard
1076    }
1077}
1078
1079impl DerefMut for ConnectionPoolGuard<'_> {
1080    fn deref_mut(&mut self) -> &mut Self::Target {
1081        &mut self.guard
1082    }
1083}
1084
/// When a `ConnectionPoolGuard` is dropped in a test build, the pool's invariants are
/// checked before the lock is released.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // Only compiled into test builds; release builds do no extra work on drop.
        #[cfg(test)]
        self.check_invariants();
    }
}
1091
1092fn broadcast<F>(
1093    sender_id: Option<ConnectionId>,
1094    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1095    mut f: F,
1096) where
1097    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1098{
1099    for receiver_id in receiver_ids {
1100        if Some(receiver_id) != sender_id
1101            && let Err(error) = f(receiver_id)
1102        {
1103            tracing::error!("failed to send to {:?} {}", receiver_id, error);
1104        }
1105    }
1106}
1107
1108pub struct ProtocolVersion(u32);
1109
1110impl Header for ProtocolVersion {
1111    fn name() -> &'static HeaderName {
1112        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1113        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1114    }
1115
1116    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1117    where
1118        Self: Sized,
1119        I: Iterator<Item = &'i axum::http::HeaderValue>,
1120    {
1121        let version = values
1122            .next()
1123            .ok_or_else(axum::headers::Error::invalid)?
1124            .to_str()
1125            .map_err(|_| axum::headers::Error::invalid())?
1126            .parse()
1127            .map_err(|_| axum::headers::Error::invalid())?;
1128        Ok(Self(version))
1129    }
1130
1131    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1132        values.extend([self.0.to_string().parse().unwrap()]);
1133    }
1134}
1135
1136pub struct AppVersionHeader(SemanticVersion);
1137impl Header for AppVersionHeader {
1138    fn name() -> &'static HeaderName {
1139        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1140        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1141    }
1142
1143    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1144    where
1145        Self: Sized,
1146        I: Iterator<Item = &'i axum::http::HeaderValue>,
1147    {
1148        let version = values
1149            .next()
1150            .ok_or_else(axum::headers::Error::invalid)?
1151            .to_str()
1152            .map_err(|_| axum::headers::Error::invalid())?
1153            .parse()
1154            .map_err(|_| axum::headers::Error::invalid())?;
1155        Ok(Self(version))
1156    }
1157
1158    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1159        values.extend([self.0.to_string().parse().unwrap()]);
1160    }
1161}
1162
1163#[derive(Debug)]
1164pub struct ReleaseChannelHeader(String);
1165
1166impl Header for ReleaseChannelHeader {
1167    fn name() -> &'static HeaderName {
1168        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1169        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1170    }
1171
1172    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1173    where
1174        Self: Sized,
1175        I: Iterator<Item = &'i axum::http::HeaderValue>,
1176    {
1177        Ok(Self(
1178            values
1179                .next()
1180                .ok_or_else(axum::headers::Error::invalid)?
1181                .to_str()
1182                .map_err(|_| axum::headers::Error::invalid())?
1183                .to_owned(),
1184        ))
1185    }
1186
1187    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1188        values.extend([self.0.parse().unwrap()]);
1189    }
1190}
1191
1192pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1193    Router::new()
1194        .route("/rpc", get(handle_websocket_request))
1195        .layer(
1196            ServiceBuilder::new()
1197                .layer(Extension(server.app_state.clone()))
1198                .layer(middleware::from_fn(auth::validate_header)),
1199        )
1200        .route("/metrics", get(handle_metrics))
1201        .layer(Extension(server))
1202}
1203
1204pub async fn handle_websocket_request(
1205    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1206    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1207    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1208    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1209    Extension(server): Extension<Arc<Server>>,
1210    Extension(principal): Extension<Principal>,
1211    user_agent: Option<TypedHeader<UserAgent>>,
1212    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1213    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1214    ws: WebSocketUpgrade,
1215) -> axum::response::Response {
1216    if protocol_version != rpc::PROTOCOL_VERSION {
1217        return (
1218            StatusCode::UPGRADE_REQUIRED,
1219            "client must be upgraded".to_string(),
1220        )
1221            .into_response();
1222    }
1223
1224    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1225        return (
1226            StatusCode::UPGRADE_REQUIRED,
1227            "no version header found".to_string(),
1228        )
1229            .into_response();
1230    };
1231
1232    let release_channel = release_channel_header.map(|header| header.0.0);
1233
1234    if !version.can_collaborate() {
1235        return (
1236            StatusCode::UPGRADE_REQUIRED,
1237            "client must be upgraded".to_string(),
1238        )
1239            .into_response();
1240    }
1241
1242    let socket_address = socket_address.to_string();
1243
1244    // Acquire connection guard before WebSocket upgrade
1245    let connection_guard = match ConnectionGuard::try_acquire() {
1246        Ok(guard) => guard,
1247        Err(()) => {
1248            return (
1249                StatusCode::SERVICE_UNAVAILABLE,
1250                "Too many concurrent connections",
1251            )
1252                .into_response();
1253        }
1254    };
1255
1256    ws.on_upgrade(move |socket| {
1257        let socket = socket
1258            .map_ok(to_tungstenite_message)
1259            .err_into()
1260            .with(|message| async move { to_axum_message(message) });
1261        let connection = Connection::new(Box::pin(socket));
1262        async move {
1263            server
1264                .handle_connection(
1265                    connection,
1266                    socket_address,
1267                    principal,
1268                    version,
1269                    release_channel,
1270                    user_agent.map(|header| header.to_string()),
1271                    country_code_header.map(|header| header.to_string()),
1272                    system_id_header.map(|header| header.to_string()),
1273                    None,
1274                    Executor::Production,
1275                    Some(connection_guard),
1276                )
1277                .await;
1278        }
1279    })
1280}
1281
1282pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1283    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1284    let connections_metric = CONNECTIONS_METRIC
1285        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1286
1287    let connections = server
1288        .connection_pool
1289        .lock()
1290        .connections()
1291        .filter(|connection| !connection.admin)
1292        .count();
1293    connections_metric.set(connections as _);
1294
1295    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1296    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1297        register_int_gauge!(
1298            "shared_projects",
1299            "number of open projects with one or more guests"
1300        )
1301        .unwrap()
1302    });
1303
1304    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1305    shared_projects_metric.set(shared_projects as _);
1306
1307    let encoder = prometheus::TextEncoder::new();
1308    let metric_families = prometheus::gather();
1309    let encoded_metrics = encoder
1310        .encode_to_string(&metric_families)
1311        .map_err(|err| anyhow!("{err}"))?;
1312    Ok(encoded_metrics)
1313}
1314
1315#[instrument(err, skip(executor))]
1316async fn connection_lost(
1317    session: Session,
1318    mut teardown: watch::Receiver<bool>,
1319    executor: Executor,
1320) -> Result<()> {
1321    session.peer.disconnect(session.connection_id);
1322    session
1323        .connection_pool()
1324        .await
1325        .remove_connection(session.connection_id)?;
1326
1327    session
1328        .db()
1329        .await
1330        .connection_lost(session.connection_id)
1331        .await
1332        .trace_err();
1333
1334    futures::select_biased! {
1335        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1336
1337            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1338            leave_room_for_session(&session, session.connection_id).await.trace_err();
1339            leave_channel_buffers_for_session(&session)
1340                .await
1341                .trace_err();
1342
1343            if !session
1344                .connection_pool()
1345                .await
1346                .is_user_online(session.user_id())
1347            {
1348                let db = session.db().await;
1349                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1350                    room_updated(&room, &session.peer);
1351                }
1352            }
1353
1354            update_user_contacts(session.user_id(), &session).await?;
1355        },
1356        _ = teardown.changed().fuse() => {}
1357    }
1358
1359    Ok(())
1360}
1361
1362/// Acknowledges a ping from a client, used to keep the connection alive.
1363async fn ping(
1364    _: proto::Ping,
1365    response: Response<proto::Ping>,
1366    _session: MessageContext,
1367) -> Result<()> {
1368    response.send(proto::Ack {})?;
1369    Ok(())
1370}
1371
1372/// Creates a new room for calling (outside of channels)
1373async fn create_room(
1374    _request: proto::CreateRoom,
1375    response: Response<proto::CreateRoom>,
1376    session: MessageContext,
1377) -> Result<()> {
1378    let livekit_room = nanoid::nanoid!(30);
1379
1380    let live_kit_connection_info = util::maybe!(async {
1381        let live_kit = session.app_state.livekit_client.as_ref();
1382        let live_kit = live_kit?;
1383        let user_id = session.user_id().to_string();
1384
1385        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1386
1387        Some(proto::LiveKitConnectionInfo {
1388            server_url: live_kit.url().into(),
1389            token,
1390            can_publish: true,
1391        })
1392    })
1393    .await;
1394
1395    let room = session
1396        .db()
1397        .await
1398        .create_room(session.user_id(), session.connection_id, &livekit_room)
1399        .await?;
1400
1401    response.send(proto::CreateRoomResponse {
1402        room: Some(room.clone()),
1403        live_kit_connection_info,
1404    })?;
1405
1406    update_user_contacts(session.user_id(), &session).await?;
1407    Ok(())
1408}
1409
1410/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1411async fn join_room(
1412    request: proto::JoinRoom,
1413    response: Response<proto::JoinRoom>,
1414    session: MessageContext,
1415) -> Result<()> {
1416    let room_id = RoomId::from_proto(request.id);
1417
1418    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1419
1420    if let Some(channel_id) = channel_id {
1421        return join_channel_internal(channel_id, Box::new(response), session).await;
1422    }
1423
1424    let joined_room = {
1425        let room = session
1426            .db()
1427            .await
1428            .join_room(room_id, session.user_id(), session.connection_id)
1429            .await?;
1430        room_updated(&room.room, &session.peer);
1431        room.into_inner()
1432    };
1433
1434    for connection_id in session
1435        .connection_pool()
1436        .await
1437        .user_connection_ids(session.user_id())
1438    {
1439        session
1440            .peer
1441            .send(
1442                connection_id,
1443                proto::CallCanceled {
1444                    room_id: room_id.to_proto(),
1445                },
1446            )
1447            .trace_err();
1448    }
1449
1450    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1451    {
1452        live_kit
1453            .room_token(
1454                &joined_room.room.livekit_room,
1455                &session.user_id().to_string(),
1456            )
1457            .trace_err()
1458            .map(|token| proto::LiveKitConnectionInfo {
1459                server_url: live_kit.url().into(),
1460                token,
1461                can_publish: true,
1462            })
1463    } else {
1464        None
1465    };
1466
1467    response.send(proto::JoinRoomResponse {
1468        room: Some(joined_room.room),
1469        channel_id: None,
1470        live_kit_connection_info,
1471    })?;
1472
1473    update_user_contacts(session.user_id(), &session).await?;
1474    Ok(())
1475}
1476
1477/// Rejoin room is used to reconnect to a room after connection errors.
1478async fn rejoin_room(
1479    request: proto::RejoinRoom,
1480    response: Response<proto::RejoinRoom>,
1481    session: MessageContext,
1482) -> Result<()> {
1483    let room;
1484    let channel;
1485    {
1486        let mut rejoined_room = session
1487            .db()
1488            .await
1489            .rejoin_room(request, session.user_id(), session.connection_id)
1490            .await?;
1491
1492        response.send(proto::RejoinRoomResponse {
1493            room: Some(rejoined_room.room.clone()),
1494            reshared_projects: rejoined_room
1495                .reshared_projects
1496                .iter()
1497                .map(|project| proto::ResharedProject {
1498                    id: project.id.to_proto(),
1499                    collaborators: project
1500                        .collaborators
1501                        .iter()
1502                        .map(|collaborator| collaborator.to_proto())
1503                        .collect(),
1504                })
1505                .collect(),
1506            rejoined_projects: rejoined_room
1507                .rejoined_projects
1508                .iter()
1509                .map(|rejoined_project| rejoined_project.to_proto())
1510                .collect(),
1511        })?;
1512        room_updated(&rejoined_room.room, &session.peer);
1513
1514        for project in &rejoined_room.reshared_projects {
1515            for collaborator in &project.collaborators {
1516                session
1517                    .peer
1518                    .send(
1519                        collaborator.connection_id,
1520                        proto::UpdateProjectCollaborator {
1521                            project_id: project.id.to_proto(),
1522                            old_peer_id: Some(project.old_connection_id.into()),
1523                            new_peer_id: Some(session.connection_id.into()),
1524                        },
1525                    )
1526                    .trace_err();
1527            }
1528
1529            broadcast(
1530                Some(session.connection_id),
1531                project
1532                    .collaborators
1533                    .iter()
1534                    .map(|collaborator| collaborator.connection_id),
1535                |connection_id| {
1536                    session.peer.forward_send(
1537                        session.connection_id,
1538                        connection_id,
1539                        proto::UpdateProject {
1540                            project_id: project.id.to_proto(),
1541                            worktrees: project.worktrees.clone(),
1542                        },
1543                    )
1544                },
1545            );
1546        }
1547
1548        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1549
1550        let rejoined_room = rejoined_room.into_inner();
1551
1552        room = rejoined_room.room;
1553        channel = rejoined_room.channel;
1554    }
1555
1556    if let Some(channel) = channel {
1557        channel_updated(
1558            &channel,
1559            &room,
1560            &session.peer,
1561            &*session.connection_pool().await,
1562        );
1563    }
1564
1565    update_user_contacts(session.user_id(), &session).await?;
1566    Ok(())
1567}
1568
1569fn notify_rejoined_projects(
1570    rejoined_projects: &mut Vec<RejoinedProject>,
1571    session: &Session,
1572) -> Result<()> {
1573    for project in rejoined_projects.iter() {
1574        for collaborator in &project.collaborators {
1575            session
1576                .peer
1577                .send(
1578                    collaborator.connection_id,
1579                    proto::UpdateProjectCollaborator {
1580                        project_id: project.id.to_proto(),
1581                        old_peer_id: Some(project.old_connection_id.into()),
1582                        new_peer_id: Some(session.connection_id.into()),
1583                    },
1584                )
1585                .trace_err();
1586        }
1587    }
1588
1589    for project in rejoined_projects {
1590        for worktree in mem::take(&mut project.worktrees) {
1591            // Stream this worktree's entries.
1592            let message = proto::UpdateWorktree {
1593                project_id: project.id.to_proto(),
1594                worktree_id: worktree.id,
1595                abs_path: worktree.abs_path.clone(),
1596                root_name: worktree.root_name,
1597                updated_entries: worktree.updated_entries,
1598                removed_entries: worktree.removed_entries,
1599                scan_id: worktree.scan_id,
1600                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1601                updated_repositories: worktree.updated_repositories,
1602                removed_repositories: worktree.removed_repositories,
1603            };
1604            for update in proto::split_worktree_update(message) {
1605                session.peer.send(session.connection_id, update)?;
1606            }
1607
1608            // Stream this worktree's diagnostics.
1609            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1610            if let Some(summary) = worktree_diagnostics.next() {
1611                let message = proto::UpdateDiagnosticSummary {
1612                    project_id: project.id.to_proto(),
1613                    worktree_id: worktree.id,
1614                    summary: Some(summary),
1615                    more_summaries: worktree_diagnostics.collect(),
1616                };
1617                session.peer.send(session.connection_id, message)?;
1618            }
1619
1620            for settings_file in worktree.settings_files {
1621                session.peer.send(
1622                    session.connection_id,
1623                    proto::UpdateWorktreeSettings {
1624                        project_id: project.id.to_proto(),
1625                        worktree_id: worktree.id,
1626                        path: settings_file.path,
1627                        content: Some(settings_file.content),
1628                        kind: Some(settings_file.kind.to_proto().into()),
1629                    },
1630                )?;
1631            }
1632        }
1633
1634        for repository in mem::take(&mut project.updated_repositories) {
1635            for update in split_repository_update(repository) {
1636                session.peer.send(session.connection_id, update)?;
1637            }
1638        }
1639
1640        for id in mem::take(&mut project.removed_repositories) {
1641            session.peer.send(
1642                session.connection_id,
1643                proto::RemoveRepository {
1644                    project_id: project.id.to_proto(),
1645                    id,
1646                },
1647            )?;
1648        }
1649    }
1650
1651    Ok(())
1652}
1653
1654/// leave room disconnects from the room.
1655async fn leave_room(
1656    _: proto::LeaveRoom,
1657    response: Response<proto::LeaveRoom>,
1658    session: MessageContext,
1659) -> Result<()> {
1660    leave_room_for_session(&session, session.connection_id).await?;
1661    response.send(proto::Ack {})?;
1662    Ok(())
1663}
1664
1665/// Updates the permissions of someone else in the room.
1666async fn set_room_participant_role(
1667    request: proto::SetRoomParticipantRole,
1668    response: Response<proto::SetRoomParticipantRole>,
1669    session: MessageContext,
1670) -> Result<()> {
1671    let user_id = UserId::from_proto(request.user_id);
1672    let role = ChannelRole::from(request.role());
1673
1674    let (livekit_room, can_publish) = {
1675        let room = session
1676            .db()
1677            .await
1678            .set_room_participant_role(
1679                session.user_id(),
1680                RoomId::from_proto(request.room_id),
1681                user_id,
1682                role,
1683            )
1684            .await?;
1685
1686        let livekit_room = room.livekit_room.clone();
1687        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1688        room_updated(&room, &session.peer);
1689        (livekit_room, can_publish)
1690    };
1691
1692    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1693        live_kit
1694            .update_participant(
1695                livekit_room.clone(),
1696                request.user_id.to_string(),
1697                livekit_api::proto::ParticipantPermission {
1698                    can_subscribe: true,
1699                    can_publish,
1700                    can_publish_data: can_publish,
1701                    hidden: false,
1702                    recorder: false,
1703                },
1704            )
1705            .await
1706            .trace_err();
1707    }
1708
1709    response.send(proto::Ack {})?;
1710    Ok(())
1711}
1712
1713/// Call someone else into the current room
1714async fn call(
1715    request: proto::Call,
1716    response: Response<proto::Call>,
1717    session: MessageContext,
1718) -> Result<()> {
1719    let room_id = RoomId::from_proto(request.room_id);
1720    let calling_user_id = session.user_id();
1721    let calling_connection_id = session.connection_id;
1722    let called_user_id = UserId::from_proto(request.called_user_id);
1723    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1724    if !session
1725        .db()
1726        .await
1727        .has_contact(calling_user_id, called_user_id)
1728        .await?
1729    {
1730        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1731    }
1732
1733    let incoming_call = {
1734        let (room, incoming_call) = &mut *session
1735            .db()
1736            .await
1737            .call(
1738                room_id,
1739                calling_user_id,
1740                calling_connection_id,
1741                called_user_id,
1742                initial_project_id,
1743            )
1744            .await?;
1745        room_updated(room, &session.peer);
1746        mem::take(incoming_call)
1747    };
1748    update_user_contacts(called_user_id, &session).await?;
1749
1750    let mut calls = session
1751        .connection_pool()
1752        .await
1753        .user_connection_ids(called_user_id)
1754        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1755        .collect::<FuturesUnordered<_>>();
1756
1757    while let Some(call_response) = calls.next().await {
1758        match call_response.as_ref() {
1759            Ok(_) => {
1760                response.send(proto::Ack {})?;
1761                return Ok(());
1762            }
1763            Err(_) => {
1764                call_response.trace_err();
1765            }
1766        }
1767    }
1768
1769    {
1770        let room = session
1771            .db()
1772            .await
1773            .call_failed(room_id, called_user_id)
1774            .await?;
1775        room_updated(&room, &session.peer);
1776    }
1777    update_user_contacts(called_user_id, &session).await?;
1778
1779    Err(anyhow!("failed to ring user"))?
1780}
1781
1782/// Cancel an outgoing call.
1783async fn cancel_call(
1784    request: proto::CancelCall,
1785    response: Response<proto::CancelCall>,
1786    session: MessageContext,
1787) -> Result<()> {
1788    let called_user_id = UserId::from_proto(request.called_user_id);
1789    let room_id = RoomId::from_proto(request.room_id);
1790    {
1791        let room = session
1792            .db()
1793            .await
1794            .cancel_call(room_id, session.connection_id, called_user_id)
1795            .await?;
1796        room_updated(&room, &session.peer);
1797    }
1798
1799    for connection_id in session
1800        .connection_pool()
1801        .await
1802        .user_connection_ids(called_user_id)
1803    {
1804        session
1805            .peer
1806            .send(
1807                connection_id,
1808                proto::CallCanceled {
1809                    room_id: room_id.to_proto(),
1810                },
1811            )
1812            .trace_err();
1813    }
1814    response.send(proto::Ack {})?;
1815
1816    update_user_contacts(called_user_id, &session).await?;
1817    Ok(())
1818}
1819
1820/// Decline an incoming call.
1821async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1822    let room_id = RoomId::from_proto(message.room_id);
1823    {
1824        let room = session
1825            .db()
1826            .await
1827            .decline_call(Some(room_id), session.user_id())
1828            .await?
1829            .context("declining call")?;
1830        room_updated(&room, &session.peer);
1831    }
1832
1833    for connection_id in session
1834        .connection_pool()
1835        .await
1836        .user_connection_ids(session.user_id())
1837    {
1838        session
1839            .peer
1840            .send(
1841                connection_id,
1842                proto::CallCanceled {
1843                    room_id: room_id.to_proto(),
1844                },
1845            )
1846            .trace_err();
1847    }
1848    update_user_contacts(session.user_id(), &session).await?;
1849    Ok(())
1850}
1851
1852/// Updates other participants in the room with your current location.
1853async fn update_participant_location(
1854    request: proto::UpdateParticipantLocation,
1855    response: Response<proto::UpdateParticipantLocation>,
1856    session: MessageContext,
1857) -> Result<()> {
1858    let room_id = RoomId::from_proto(request.room_id);
1859    let location = request.location.context("invalid location")?;
1860
1861    let db = session.db().await;
1862    let room = db
1863        .update_room_participant_location(room_id, session.connection_id, location)
1864        .await?;
1865
1866    room_updated(&room, &session.peer);
1867    response.send(proto::Ack {})?;
1868    Ok(())
1869}
1870
1871/// Share a project into the room.
1872async fn share_project(
1873    request: proto::ShareProject,
1874    response: Response<proto::ShareProject>,
1875    session: MessageContext,
1876) -> Result<()> {
1877    let (project_id, room) = &*session
1878        .db()
1879        .await
1880        .share_project(
1881            RoomId::from_proto(request.room_id),
1882            session.connection_id,
1883            &request.worktrees,
1884            request.is_ssh_project,
1885            request.windows_paths.unwrap_or(false),
1886        )
1887        .await?;
1888    response.send(proto::ShareProjectResponse {
1889        project_id: project_id.to_proto(),
1890    })?;
1891    room_updated(room, &session.peer);
1892
1893    Ok(())
1894}
1895
1896/// Unshare a project from the room.
1897async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1898    let project_id = ProjectId::from_proto(message.project_id);
1899    unshare_project_internal(project_id, session.connection_id, &session).await
1900}
1901
1902async fn unshare_project_internal(
1903    project_id: ProjectId,
1904    connection_id: ConnectionId,
1905    session: &Session,
1906) -> Result<()> {
1907    let delete = {
1908        let room_guard = session
1909            .db()
1910            .await
1911            .unshare_project(project_id, connection_id)
1912            .await?;
1913
1914        let (delete, room, guest_connection_ids) = &*room_guard;
1915
1916        let message = proto::UnshareProject {
1917            project_id: project_id.to_proto(),
1918        };
1919
1920        broadcast(
1921            Some(connection_id),
1922            guest_connection_ids.iter().copied(),
1923            |conn_id| session.peer.send(conn_id, message.clone()),
1924        );
1925        if let Some(room) = room {
1926            room_updated(room, &session.peer);
1927        }
1928
1929        *delete
1930    };
1931
1932    if delete {
1933        let db = session.db().await;
1934        db.delete_project(project_id).await?;
1935    }
1936
1937    Ok(())
1938}
1939
1940/// Join someone elses shared project.
1941async fn join_project(
1942    request: proto::JoinProject,
1943    response: Response<proto::JoinProject>,
1944    session: MessageContext,
1945) -> Result<()> {
1946    let project_id = ProjectId::from_proto(request.project_id);
1947
1948    tracing::info!(%project_id, "join project");
1949
1950    let db = session.db().await;
1951    let (project, replica_id) = &mut *db
1952        .join_project(
1953            project_id,
1954            session.connection_id,
1955            session.user_id(),
1956            request.committer_name.clone(),
1957            request.committer_email.clone(),
1958        )
1959        .await?;
1960    drop(db);
1961    tracing::info!(%project_id, "join remote project");
1962    let collaborators = project
1963        .collaborators
1964        .iter()
1965        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1966        .map(|collaborator| collaborator.to_proto())
1967        .collect::<Vec<_>>();
1968    let project_id = project.id;
1969    let guest_user_id = session.user_id();
1970
1971    let worktrees = project
1972        .worktrees
1973        .iter()
1974        .map(|(id, worktree)| proto::WorktreeMetadata {
1975            id: *id,
1976            root_name: worktree.root_name.clone(),
1977            visible: worktree.visible,
1978            abs_path: worktree.abs_path.clone(),
1979        })
1980        .collect::<Vec<_>>();
1981
1982    let add_project_collaborator = proto::AddProjectCollaborator {
1983        project_id: project_id.to_proto(),
1984        collaborator: Some(proto::Collaborator {
1985            peer_id: Some(session.connection_id.into()),
1986            replica_id: replica_id.0 as u32,
1987            user_id: guest_user_id.to_proto(),
1988            is_host: false,
1989            committer_name: request.committer_name.clone(),
1990            committer_email: request.committer_email.clone(),
1991        }),
1992    };
1993
1994    for collaborator in &collaborators {
1995        session
1996            .peer
1997            .send(
1998                collaborator.peer_id.unwrap().into(),
1999                add_project_collaborator.clone(),
2000            )
2001            .trace_err();
2002    }
2003
2004    // First, we send the metadata associated with each worktree.
2005    let (language_servers, language_server_capabilities) = project
2006        .language_servers
2007        .clone()
2008        .into_iter()
2009        .map(|server| (server.server, server.capabilities))
2010        .unzip();
2011    response.send(proto::JoinProjectResponse {
2012        project_id: project.id.0 as u64,
2013        worktrees,
2014        replica_id: replica_id.0 as u32,
2015        collaborators,
2016        language_servers,
2017        language_server_capabilities,
2018        role: project.role.into(),
2019        windows_paths: project.path_style == PathStyle::Windows,
2020    })?;
2021
2022    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2023        // Stream this worktree's entries.
2024        let message = proto::UpdateWorktree {
2025            project_id: project_id.to_proto(),
2026            worktree_id,
2027            abs_path: worktree.abs_path.clone(),
2028            root_name: worktree.root_name,
2029            updated_entries: worktree.entries,
2030            removed_entries: Default::default(),
2031            scan_id: worktree.scan_id,
2032            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2033            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2034            removed_repositories: Default::default(),
2035        };
2036        for update in proto::split_worktree_update(message) {
2037            session.peer.send(session.connection_id, update.clone())?;
2038        }
2039
2040        // Stream this worktree's diagnostics.
2041        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
2042        if let Some(summary) = worktree_diagnostics.next() {
2043            let message = proto::UpdateDiagnosticSummary {
2044                project_id: project.id.to_proto(),
2045                worktree_id: worktree.id,
2046                summary: Some(summary),
2047                more_summaries: worktree_diagnostics.collect(),
2048            };
2049            session.peer.send(session.connection_id, message)?;
2050        }
2051
2052        for settings_file in worktree.settings_files {
2053            session.peer.send(
2054                session.connection_id,
2055                proto::UpdateWorktreeSettings {
2056                    project_id: project_id.to_proto(),
2057                    worktree_id: worktree.id,
2058                    path: settings_file.path,
2059                    content: Some(settings_file.content),
2060                    kind: Some(settings_file.kind.to_proto() as i32),
2061                },
2062            )?;
2063        }
2064    }
2065
2066    for repository in mem::take(&mut project.repositories) {
2067        for update in split_repository_update(repository) {
2068            session.peer.send(session.connection_id, update)?;
2069        }
2070    }
2071
2072    for language_server in &project.language_servers {
2073        session.peer.send(
2074            session.connection_id,
2075            proto::UpdateLanguageServer {
2076                project_id: project_id.to_proto(),
2077                server_name: Some(language_server.server.name.clone()),
2078                language_server_id: language_server.server.id,
2079                variant: Some(
2080                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2081                        proto::LspDiskBasedDiagnosticsUpdated {},
2082                    ),
2083                ),
2084            },
2085        )?;
2086    }
2087
2088    Ok(())
2089}
2090
2091/// Leave someone elses shared project.
2092async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2093    let sender_id = session.connection_id;
2094    let project_id = ProjectId::from_proto(request.project_id);
2095    let db = session.db().await;
2096
2097    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2098    tracing::info!(
2099        %project_id,
2100        "leave project"
2101    );
2102
2103    project_left(project, &session);
2104    if let Some(room) = room {
2105        room_updated(room, &session.peer);
2106    }
2107
2108    Ok(())
2109}
2110
2111/// Updates other participants with changes to the project
2112async fn update_project(
2113    request: proto::UpdateProject,
2114    response: Response<proto::UpdateProject>,
2115    session: MessageContext,
2116) -> Result<()> {
2117    let project_id = ProjectId::from_proto(request.project_id);
2118    let (room, guest_connection_ids) = &*session
2119        .db()
2120        .await
2121        .update_project(project_id, session.connection_id, &request.worktrees)
2122        .await?;
2123    broadcast(
2124        Some(session.connection_id),
2125        guest_connection_ids.iter().copied(),
2126        |connection_id| {
2127            session
2128                .peer
2129                .forward_send(session.connection_id, connection_id, request.clone())
2130        },
2131    );
2132    if let Some(room) = room {
2133        room_updated(room, &session.peer);
2134    }
2135    response.send(proto::Ack {})?;
2136
2137    Ok(())
2138}
2139
2140/// Updates other participants with changes to the worktree
2141async fn update_worktree(
2142    request: proto::UpdateWorktree,
2143    response: Response<proto::UpdateWorktree>,
2144    session: MessageContext,
2145) -> Result<()> {
2146    let guest_connection_ids = session
2147        .db()
2148        .await
2149        .update_worktree(&request, session.connection_id)
2150        .await?;
2151
2152    broadcast(
2153        Some(session.connection_id),
2154        guest_connection_ids.iter().copied(),
2155        |connection_id| {
2156            session
2157                .peer
2158                .forward_send(session.connection_id, connection_id, request.clone())
2159        },
2160    );
2161    response.send(proto::Ack {})?;
2162    Ok(())
2163}
2164
2165async fn update_repository(
2166    request: proto::UpdateRepository,
2167    response: Response<proto::UpdateRepository>,
2168    session: MessageContext,
2169) -> Result<()> {
2170    let guest_connection_ids = session
2171        .db()
2172        .await
2173        .update_repository(&request, session.connection_id)
2174        .await?;
2175
2176    broadcast(
2177        Some(session.connection_id),
2178        guest_connection_ids.iter().copied(),
2179        |connection_id| {
2180            session
2181                .peer
2182                .forward_send(session.connection_id, connection_id, request.clone())
2183        },
2184    );
2185    response.send(proto::Ack {})?;
2186    Ok(())
2187}
2188
2189async fn remove_repository(
2190    request: proto::RemoveRepository,
2191    response: Response<proto::RemoveRepository>,
2192    session: MessageContext,
2193) -> Result<()> {
2194    let guest_connection_ids = session
2195        .db()
2196        .await
2197        .remove_repository(&request, session.connection_id)
2198        .await?;
2199
2200    broadcast(
2201        Some(session.connection_id),
2202        guest_connection_ids.iter().copied(),
2203        |connection_id| {
2204            session
2205                .peer
2206                .forward_send(session.connection_id, connection_id, request.clone())
2207        },
2208    );
2209    response.send(proto::Ack {})?;
2210    Ok(())
2211}
2212
2213/// Updates other participants with changes to the diagnostics
2214async fn update_diagnostic_summary(
2215    message: proto::UpdateDiagnosticSummary,
2216    session: MessageContext,
2217) -> Result<()> {
2218    let guest_connection_ids = session
2219        .db()
2220        .await
2221        .update_diagnostic_summary(&message, session.connection_id)
2222        .await?;
2223
2224    broadcast(
2225        Some(session.connection_id),
2226        guest_connection_ids.iter().copied(),
2227        |connection_id| {
2228            session
2229                .peer
2230                .forward_send(session.connection_id, connection_id, message.clone())
2231        },
2232    );
2233
2234    Ok(())
2235}
2236
2237/// Updates other participants with changes to the worktree settings
2238async fn update_worktree_settings(
2239    message: proto::UpdateWorktreeSettings,
2240    session: MessageContext,
2241) -> Result<()> {
2242    let guest_connection_ids = session
2243        .db()
2244        .await
2245        .update_worktree_settings(&message, session.connection_id)
2246        .await?;
2247
2248    broadcast(
2249        Some(session.connection_id),
2250        guest_connection_ids.iter().copied(),
2251        |connection_id| {
2252            session
2253                .peer
2254                .forward_send(session.connection_id, connection_id, message.clone())
2255        },
2256    );
2257
2258    Ok(())
2259}
2260
2261/// Notify other participants that a language server has started.
2262async fn start_language_server(
2263    request: proto::StartLanguageServer,
2264    session: MessageContext,
2265) -> Result<()> {
2266    let guest_connection_ids = session
2267        .db()
2268        .await
2269        .start_language_server(&request, session.connection_id)
2270        .await?;
2271
2272    broadcast(
2273        Some(session.connection_id),
2274        guest_connection_ids.iter().copied(),
2275        |connection_id| {
2276            session
2277                .peer
2278                .forward_send(session.connection_id, connection_id, request.clone())
2279        },
2280    );
2281    Ok(())
2282}
2283
2284/// Notify other participants that a language server has changed.
2285async fn update_language_server(
2286    request: proto::UpdateLanguageServer,
2287    session: MessageContext,
2288) -> Result<()> {
2289    let project_id = ProjectId::from_proto(request.project_id);
2290    let db = session.db().await;
2291
2292    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2293        && let Some(capabilities) = update.capabilities.clone()
2294    {
2295        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2296            .await?;
2297    }
2298
2299    let project_connection_ids = db
2300        .project_connection_ids(project_id, session.connection_id, true)
2301        .await?;
2302    broadcast(
2303        Some(session.connection_id),
2304        project_connection_ids.iter().copied(),
2305        |connection_id| {
2306            session
2307                .peer
2308                .forward_send(session.connection_id, connection_id, request.clone())
2309        },
2310    );
2311    Ok(())
2312}
2313
2314/// forward a project request to the host. These requests should be read only
2315/// as guests are allowed to send them.
2316async fn forward_read_only_project_request<T>(
2317    request: T,
2318    response: Response<T>,
2319    session: MessageContext,
2320) -> Result<()>
2321where
2322    T: EntityMessage + RequestMessage,
2323{
2324    let project_id = ProjectId::from_proto(request.remote_entity_id());
2325    let host_connection_id = session
2326        .db()
2327        .await
2328        .host_for_read_only_project_request(project_id, session.connection_id)
2329        .await?;
2330    let payload = session.forward_request(host_connection_id, request).await?;
2331    response.send(payload)?;
2332    Ok(())
2333}
2334
2335/// forward a project request to the host. These requests are disallowed
2336/// for guests.
2337async fn forward_mutating_project_request<T>(
2338    request: T,
2339    response: Response<T>,
2340    session: MessageContext,
2341) -> Result<()>
2342where
2343    T: EntityMessage + RequestMessage,
2344{
2345    let project_id = ProjectId::from_proto(request.remote_entity_id());
2346
2347    let host_connection_id = session
2348        .db()
2349        .await
2350        .host_for_mutating_project_request(project_id, session.connection_id)
2351        .await?;
2352    let payload = session.forward_request(host_connection_id, request).await?;
2353    response.send(payload)?;
2354    Ok(())
2355}
2356
2357async fn lsp_query(
2358    request: proto::LspQuery,
2359    response: Response<proto::LspQuery>,
2360    session: MessageContext,
2361) -> Result<()> {
2362    let (name, should_write) = request.query_name_and_write_permissions();
2363    tracing::Span::current().record("lsp_query_request", name);
2364    tracing::info!("lsp_query message received");
2365    if should_write {
2366        forward_mutating_project_request(request, response, session).await
2367    } else {
2368        forward_read_only_project_request(request, response, session).await
2369    }
2370}
2371
2372/// Notify other participants that a new buffer has been created
2373async fn create_buffer_for_peer(
2374    request: proto::CreateBufferForPeer,
2375    session: MessageContext,
2376) -> Result<()> {
2377    session
2378        .db()
2379        .await
2380        .check_user_is_project_host(
2381            ProjectId::from_proto(request.project_id),
2382            session.connection_id,
2383        )
2384        .await?;
2385    let peer_id = request.peer_id.context("invalid peer id")?;
2386    session
2387        .peer
2388        .forward_send(session.connection_id, peer_id.into(), request)?;
2389    Ok(())
2390}
2391
2392/// Notify other participants that a buffer has been updated. This is
2393/// allowed for guests as long as the update is limited to selections.
2394async fn update_buffer(
2395    request: proto::UpdateBuffer,
2396    response: Response<proto::UpdateBuffer>,
2397    session: MessageContext,
2398) -> Result<()> {
2399    let project_id = ProjectId::from_proto(request.project_id);
2400    let mut capability = Capability::ReadOnly;
2401
2402    for op in request.operations.iter() {
2403        match op.variant {
2404            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2405            Some(_) => capability = Capability::ReadWrite,
2406        }
2407    }
2408
2409    let host = {
2410        let guard = session
2411            .db()
2412            .await
2413            .connections_for_buffer_update(project_id, session.connection_id, capability)
2414            .await?;
2415
2416        let (host, guests) = &*guard;
2417
2418        broadcast(
2419            Some(session.connection_id),
2420            guests.clone(),
2421            |connection_id| {
2422                session
2423                    .peer
2424                    .forward_send(session.connection_id, connection_id, request.clone())
2425            },
2426        );
2427
2428        *host
2429    };
2430
2431    if host != session.connection_id {
2432        session.forward_request(host, request.clone()).await?;
2433    }
2434
2435    response.send(proto::Ack {})?;
2436    Ok(())
2437}
2438
2439async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2440    let project_id = ProjectId::from_proto(message.project_id);
2441
2442    let operation = message.operation.as_ref().context("invalid operation")?;
2443    let capability = match operation.variant.as_ref() {
2444        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2445            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2446                match buffer_op.variant {
2447                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2448                        Capability::ReadOnly
2449                    }
2450                    _ => Capability::ReadWrite,
2451                }
2452            } else {
2453                Capability::ReadWrite
2454            }
2455        }
2456        Some(_) => Capability::ReadWrite,
2457        None => Capability::ReadOnly,
2458    };
2459
2460    let guard = session
2461        .db()
2462        .await
2463        .connections_for_buffer_update(project_id, session.connection_id, capability)
2464        .await?;
2465
2466    let (host, guests) = &*guard;
2467
2468    broadcast(
2469        Some(session.connection_id),
2470        guests.iter().chain([host]).copied(),
2471        |connection_id| {
2472            session
2473                .peer
2474                .forward_send(session.connection_id, connection_id, message.clone())
2475        },
2476    );
2477
2478    Ok(())
2479}
2480
2481/// Notify other participants that a project has been updated.
2482async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2483    request: T,
2484    session: MessageContext,
2485) -> Result<()> {
2486    let project_id = ProjectId::from_proto(request.remote_entity_id());
2487    let project_connection_ids = session
2488        .db()
2489        .await
2490        .project_connection_ids(project_id, session.connection_id, false)
2491        .await?;
2492
2493    broadcast(
2494        Some(session.connection_id),
2495        project_connection_ids.iter().copied(),
2496        |connection_id| {
2497            session
2498                .peer
2499                .forward_send(session.connection_id, connection_id, request.clone())
2500        },
2501    );
2502    Ok(())
2503}
2504
2505/// Start following another user in a call.
2506async fn follow(
2507    request: proto::Follow,
2508    response: Response<proto::Follow>,
2509    session: MessageContext,
2510) -> Result<()> {
2511    let room_id = RoomId::from_proto(request.room_id);
2512    let project_id = request.project_id.map(ProjectId::from_proto);
2513    let leader_id = request.leader_id.context("invalid leader id")?.into();
2514    let follower_id = session.connection_id;
2515
2516    session
2517        .db()
2518        .await
2519        .check_room_participants(room_id, leader_id, session.connection_id)
2520        .await?;
2521
2522    let response_payload = session.forward_request(leader_id, request).await?;
2523    response.send(response_payload)?;
2524
2525    if let Some(project_id) = project_id {
2526        let room = session
2527            .db()
2528            .await
2529            .follow(room_id, project_id, leader_id, follower_id)
2530            .await?;
2531        room_updated(&room, &session.peer);
2532    }
2533
2534    Ok(())
2535}
2536
2537/// Stop following another user in a call.
2538async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2539    let room_id = RoomId::from_proto(request.room_id);
2540    let project_id = request.project_id.map(ProjectId::from_proto);
2541    let leader_id = request.leader_id.context("invalid leader id")?.into();
2542    let follower_id = session.connection_id;
2543
2544    session
2545        .db()
2546        .await
2547        .check_room_participants(room_id, leader_id, session.connection_id)
2548        .await?;
2549
2550    session
2551        .peer
2552        .forward_send(session.connection_id, leader_id, request)?;
2553
2554    if let Some(project_id) = project_id {
2555        let room = session
2556            .db()
2557            .await
2558            .unfollow(room_id, project_id, leader_id, follower_id)
2559            .await?;
2560        room_updated(&room, &session.peer);
2561    }
2562
2563    Ok(())
2564}
2565
2566/// Notify everyone following you of your current location.
2567async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2568    let room_id = RoomId::from_proto(request.room_id);
2569    let database = session.db.lock().await;
2570
2571    let connection_ids = if let Some(project_id) = request.project_id {
2572        let project_id = ProjectId::from_proto(project_id);
2573        database
2574            .project_connection_ids(project_id, session.connection_id, true)
2575            .await?
2576    } else {
2577        database
2578            .room_connection_ids(room_id, session.connection_id)
2579            .await?
2580    };
2581
2582    // For now, don't send view update messages back to that view's current leader.
2583    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2584        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2585        _ => None,
2586    });
2587
2588    for connection_id in connection_ids.iter().cloned() {
2589        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2590            session
2591                .peer
2592                .forward_send(session.connection_id, connection_id, request.clone())?;
2593        }
2594    }
2595    Ok(())
2596}
2597
2598/// Get public data about users.
2599async fn get_users(
2600    request: proto::GetUsers,
2601    response: Response<proto::GetUsers>,
2602    session: MessageContext,
2603) -> Result<()> {
2604    let user_ids = request
2605        .user_ids
2606        .into_iter()
2607        .map(UserId::from_proto)
2608        .collect();
2609    let users = session
2610        .db()
2611        .await
2612        .get_users_by_ids(user_ids)
2613        .await?
2614        .into_iter()
2615        .map(|user| proto::User {
2616            id: user.id.to_proto(),
2617            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2618            github_login: user.github_login,
2619            name: user.name,
2620        })
2621        .collect();
2622    response.send(proto::UsersResponse { users })?;
2623    Ok(())
2624}
2625
2626/// Search for users (to invite) buy Github login
2627async fn fuzzy_search_users(
2628    request: proto::FuzzySearchUsers,
2629    response: Response<proto::FuzzySearchUsers>,
2630    session: MessageContext,
2631) -> Result<()> {
2632    let query = request.query;
2633    let users = match query.len() {
2634        0 => vec![],
2635        1 | 2 => session
2636            .db()
2637            .await
2638            .get_user_by_github_login(&query)
2639            .await?
2640            .into_iter()
2641            .collect(),
2642        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2643    };
2644    let users = users
2645        .into_iter()
2646        .filter(|user| user.id != session.user_id())
2647        .map(|user| proto::User {
2648            id: user.id.to_proto(),
2649            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2650            github_login: user.github_login,
2651            name: user.name,
2652        })
2653        .collect();
2654    response.send(proto::UsersResponse { users })?;
2655    Ok(())
2656}
2657
2658/// Send a contact request to another user.
2659async fn request_contact(
2660    request: proto::RequestContact,
2661    response: Response<proto::RequestContact>,
2662    session: MessageContext,
2663) -> Result<()> {
2664    let requester_id = session.user_id();
2665    let responder_id = UserId::from_proto(request.responder_id);
2666    if requester_id == responder_id {
2667        return Err(anyhow!("cannot add yourself as a contact"))?;
2668    }
2669
2670    let notifications = session
2671        .db()
2672        .await
2673        .send_contact_request(requester_id, responder_id)
2674        .await?;
2675
2676    // Update outgoing contact requests of requester
2677    let mut update = proto::UpdateContacts::default();
2678    update.outgoing_requests.push(responder_id.to_proto());
2679    for connection_id in session
2680        .connection_pool()
2681        .await
2682        .user_connection_ids(requester_id)
2683    {
2684        session.peer.send(connection_id, update.clone())?;
2685    }
2686
2687    // Update incoming contact requests of responder
2688    let mut update = proto::UpdateContacts::default();
2689    update
2690        .incoming_requests
2691        .push(proto::IncomingContactRequest {
2692            requester_id: requester_id.to_proto(),
2693        });
2694    let connection_pool = session.connection_pool().await;
2695    for connection_id in connection_pool.user_connection_ids(responder_id) {
2696        session.peer.send(connection_id, update.clone())?;
2697    }
2698
2699    send_notifications(&connection_pool, &session.peer, notifications);
2700
2701    response.send(proto::Ack {})?;
2702    Ok(())
2703}
2704
2705/// Accept or decline a contact request
2706async fn respond_to_contact_request(
2707    request: proto::RespondToContactRequest,
2708    response: Response<proto::RespondToContactRequest>,
2709    session: MessageContext,
2710) -> Result<()> {
2711    let responder_id = session.user_id();
2712    let requester_id = UserId::from_proto(request.requester_id);
2713    let db = session.db().await;
2714    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2715        db.dismiss_contact_notification(responder_id, requester_id)
2716            .await?;
2717    } else {
2718        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2719
2720        let notifications = db
2721            .respond_to_contact_request(responder_id, requester_id, accept)
2722            .await?;
2723        let requester_busy = db.is_user_busy(requester_id).await?;
2724        let responder_busy = db.is_user_busy(responder_id).await?;
2725
2726        let pool = session.connection_pool().await;
2727        // Update responder with new contact
2728        let mut update = proto::UpdateContacts::default();
2729        if accept {
2730            update
2731                .contacts
2732                .push(contact_for_user(requester_id, requester_busy, &pool));
2733        }
2734        update
2735            .remove_incoming_requests
2736            .push(requester_id.to_proto());
2737        for connection_id in pool.user_connection_ids(responder_id) {
2738            session.peer.send(connection_id, update.clone())?;
2739        }
2740
2741        // Update requester with new contact
2742        let mut update = proto::UpdateContacts::default();
2743        if accept {
2744            update
2745                .contacts
2746                .push(contact_for_user(responder_id, responder_busy, &pool));
2747        }
2748        update
2749            .remove_outgoing_requests
2750            .push(responder_id.to_proto());
2751
2752        for connection_id in pool.user_connection_ids(requester_id) {
2753            session.peer.send(connection_id, update.clone())?;
2754        }
2755
2756        send_notifications(&pool, &session.peer, notifications);
2757    }
2758
2759    response.send(proto::Ack {})?;
2760    Ok(())
2761}
2762
2763/// Remove a contact.
2764async fn remove_contact(
2765    request: proto::RemoveContact,
2766    response: Response<proto::RemoveContact>,
2767    session: MessageContext,
2768) -> Result<()> {
2769    let requester_id = session.user_id();
2770    let responder_id = UserId::from_proto(request.user_id);
2771    let db = session.db().await;
2772    let (contact_accepted, deleted_notification_id) =
2773        db.remove_contact(requester_id, responder_id).await?;
2774
2775    let pool = session.connection_pool().await;
2776    // Update outgoing contact requests of requester
2777    let mut update = proto::UpdateContacts::default();
2778    if contact_accepted {
2779        update.remove_contacts.push(responder_id.to_proto());
2780    } else {
2781        update
2782            .remove_outgoing_requests
2783            .push(responder_id.to_proto());
2784    }
2785    for connection_id in pool.user_connection_ids(requester_id) {
2786        session.peer.send(connection_id, update.clone())?;
2787    }
2788
2789    // Update incoming contact requests of responder
2790    let mut update = proto::UpdateContacts::default();
2791    if contact_accepted {
2792        update.remove_contacts.push(requester_id.to_proto());
2793    } else {
2794        update
2795            .remove_incoming_requests
2796            .push(requester_id.to_proto());
2797    }
2798    for connection_id in pool.user_connection_ids(responder_id) {
2799        session.peer.send(connection_id, update.clone())?;
2800        if let Some(notification_id) = deleted_notification_id {
2801            session.peer.send(
2802                connection_id,
2803                proto::DeleteNotification {
2804                    notification_id: notification_id.to_proto(),
2805                },
2806            )?;
2807        }
2808    }
2809
2810    response.send(proto::Ack {})?;
2811    Ok(())
2812}
2813
2814fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2815    version.0.minor() < 139
2816}
2817
2818async fn subscribe_to_channels(
2819    _: proto::SubscribeToChannels,
2820    session: MessageContext,
2821) -> Result<()> {
2822    subscribe_user_to_channels(session.user_id(), &session).await?;
2823    Ok(())
2824}
2825
2826async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2827    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2828    let mut pool = session.connection_pool().await;
2829    for membership in &channels_for_user.channel_memberships {
2830        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2831    }
2832    session.peer.send(
2833        session.connection_id,
2834        build_update_user_channels(&channels_for_user),
2835    )?;
2836    session.peer.send(
2837        session.connection_id,
2838        build_channels_update(channels_for_user),
2839    )?;
2840    Ok(())
2841}
2842
2843/// Creates a new channel.
2844async fn create_channel(
2845    request: proto::CreateChannel,
2846    response: Response<proto::CreateChannel>,
2847    session: MessageContext,
2848) -> Result<()> {
2849    let db = session.db().await;
2850
2851    let parent_id = request.parent_id.map(ChannelId::from_proto);
2852    let (channel, membership) = db
2853        .create_channel(&request.name, parent_id, session.user_id())
2854        .await?;
2855
2856    let root_id = channel.root_id();
2857    let channel = Channel::from_model(channel);
2858
2859    response.send(proto::CreateChannelResponse {
2860        channel: Some(channel.to_proto()),
2861        parent_id: request.parent_id,
2862    })?;
2863
2864    let mut connection_pool = session.connection_pool().await;
2865    if let Some(membership) = membership {
2866        connection_pool.subscribe_to_channel(
2867            membership.user_id,
2868            membership.channel_id,
2869            membership.role,
2870        );
2871        let update = proto::UpdateUserChannels {
2872            channel_memberships: vec![proto::ChannelMembership {
2873                channel_id: membership.channel_id.to_proto(),
2874                role: membership.role.into(),
2875            }],
2876            ..Default::default()
2877        };
2878        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2879            session.peer.send(connection_id, update.clone())?;
2880        }
2881    }
2882
2883    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2884        if !role.can_see_channel(channel.visibility) {
2885            continue;
2886        }
2887
2888        let update = proto::UpdateChannels {
2889            channels: vec![channel.to_proto()],
2890            ..Default::default()
2891        };
2892        session.peer.send(connection_id, update.clone())?;
2893    }
2894
2895    Ok(())
2896}
2897
2898/// Delete a channel
2899async fn delete_channel(
2900    request: proto::DeleteChannel,
2901    response: Response<proto::DeleteChannel>,
2902    session: MessageContext,
2903) -> Result<()> {
2904    let db = session.db().await;
2905
2906    let channel_id = request.channel_id;
2907    let (root_channel, removed_channels) = db
2908        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2909        .await?;
2910    response.send(proto::Ack {})?;
2911
2912    // Notify members of removed channels
2913    let mut update = proto::UpdateChannels::default();
2914    update
2915        .delete_channels
2916        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2917
2918    let connection_pool = session.connection_pool().await;
2919    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2920        session.peer.send(connection_id, update.clone())?;
2921    }
2922
2923    Ok(())
2924}
2925
2926/// Invite someone to join a channel.
2927async fn invite_channel_member(
2928    request: proto::InviteChannelMember,
2929    response: Response<proto::InviteChannelMember>,
2930    session: MessageContext,
2931) -> Result<()> {
2932    let db = session.db().await;
2933    let channel_id = ChannelId::from_proto(request.channel_id);
2934    let invitee_id = UserId::from_proto(request.user_id);
2935    let InviteMemberResult {
2936        channel,
2937        notifications,
2938    } = db
2939        .invite_channel_member(
2940            channel_id,
2941            invitee_id,
2942            session.user_id(),
2943            request.role().into(),
2944        )
2945        .await?;
2946
2947    let update = proto::UpdateChannels {
2948        channel_invitations: vec![channel.to_proto()],
2949        ..Default::default()
2950    };
2951
2952    let connection_pool = session.connection_pool().await;
2953    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2954        session.peer.send(connection_id, update.clone())?;
2955    }
2956
2957    send_notifications(&connection_pool, &session.peer, notifications);
2958
2959    response.send(proto::Ack {})?;
2960    Ok(())
2961}
2962
2963/// remove someone from a channel
2964async fn remove_channel_member(
2965    request: proto::RemoveChannelMember,
2966    response: Response<proto::RemoveChannelMember>,
2967    session: MessageContext,
2968) -> Result<()> {
2969    let db = session.db().await;
2970    let channel_id = ChannelId::from_proto(request.channel_id);
2971    let member_id = UserId::from_proto(request.user_id);
2972
2973    let RemoveChannelMemberResult {
2974        membership_update,
2975        notification_id,
2976    } = db
2977        .remove_channel_member(channel_id, member_id, session.user_id())
2978        .await?;
2979
2980    let mut connection_pool = session.connection_pool().await;
2981    notify_membership_updated(
2982        &mut connection_pool,
2983        membership_update,
2984        member_id,
2985        &session.peer,
2986    );
2987    for connection_id in connection_pool.user_connection_ids(member_id) {
2988        if let Some(notification_id) = notification_id {
2989            session
2990                .peer
2991                .send(
2992                    connection_id,
2993                    proto::DeleteNotification {
2994                        notification_id: notification_id.to_proto(),
2995                    },
2996                )
2997                .trace_err();
2998        }
2999    }
3000
3001    response.send(proto::Ack {})?;
3002    Ok(())
3003}
3004
3005/// Toggle the channel between public and private.
3006/// Care is taken to maintain the invariant that public channels only descend from public channels,
3007/// (though members-only channels can appear at any point in the hierarchy).
3008async fn set_channel_visibility(
3009    request: proto::SetChannelVisibility,
3010    response: Response<proto::SetChannelVisibility>,
3011    session: MessageContext,
3012) -> Result<()> {
3013    let db = session.db().await;
3014    let channel_id = ChannelId::from_proto(request.channel_id);
3015    let visibility = request.visibility().into();
3016
3017    let channel_model = db
3018        .set_channel_visibility(channel_id, visibility, session.user_id())
3019        .await?;
3020    let root_id = channel_model.root_id();
3021    let channel = Channel::from_model(channel_model);
3022
3023    let mut connection_pool = session.connection_pool().await;
3024    for (user_id, role) in connection_pool
3025        .channel_user_ids(root_id)
3026        .collect::<Vec<_>>()
3027        .into_iter()
3028    {
3029        let update = if role.can_see_channel(channel.visibility) {
3030            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3031            proto::UpdateChannels {
3032                channels: vec![channel.to_proto()],
3033                ..Default::default()
3034            }
3035        } else {
3036            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3037            proto::UpdateChannels {
3038                delete_channels: vec![channel.id.to_proto()],
3039                ..Default::default()
3040            }
3041        };
3042
3043        for connection_id in connection_pool.user_connection_ids(user_id) {
3044            session.peer.send(connection_id, update.clone())?;
3045        }
3046    }
3047
3048    response.send(proto::Ack {})?;
3049    Ok(())
3050}
3051
3052/// Alter the role for a user in the channel.
3053async fn set_channel_member_role(
3054    request: proto::SetChannelMemberRole,
3055    response: Response<proto::SetChannelMemberRole>,
3056    session: MessageContext,
3057) -> Result<()> {
3058    let db = session.db().await;
3059    let channel_id = ChannelId::from_proto(request.channel_id);
3060    let member_id = UserId::from_proto(request.user_id);
3061    let result = db
3062        .set_channel_member_role(
3063            channel_id,
3064            session.user_id(),
3065            member_id,
3066            request.role().into(),
3067        )
3068        .await?;
3069
3070    match result {
3071        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3072            let mut connection_pool = session.connection_pool().await;
3073            notify_membership_updated(
3074                &mut connection_pool,
3075                membership_update,
3076                member_id,
3077                &session.peer,
3078            )
3079        }
3080        db::SetMemberRoleResult::InviteUpdated(channel) => {
3081            let update = proto::UpdateChannels {
3082                channel_invitations: vec![channel.to_proto()],
3083                ..Default::default()
3084            };
3085
3086            for connection_id in session
3087                .connection_pool()
3088                .await
3089                .user_connection_ids(member_id)
3090            {
3091                session.peer.send(connection_id, update.clone())?;
3092            }
3093        }
3094    }
3095
3096    response.send(proto::Ack {})?;
3097    Ok(())
3098}
3099
3100/// Change the name of a channel
3101async fn rename_channel(
3102    request: proto::RenameChannel,
3103    response: Response<proto::RenameChannel>,
3104    session: MessageContext,
3105) -> Result<()> {
3106    let db = session.db().await;
3107    let channel_id = ChannelId::from_proto(request.channel_id);
3108    let channel_model = db
3109        .rename_channel(channel_id, session.user_id(), &request.name)
3110        .await?;
3111    let root_id = channel_model.root_id();
3112    let channel = Channel::from_model(channel_model);
3113
3114    response.send(proto::RenameChannelResponse {
3115        channel: Some(channel.to_proto()),
3116    })?;
3117
3118    let connection_pool = session.connection_pool().await;
3119    let update = proto::UpdateChannels {
3120        channels: vec![channel.to_proto()],
3121        ..Default::default()
3122    };
3123    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3124        if role.can_see_channel(channel.visibility) {
3125            session.peer.send(connection_id, update.clone())?;
3126        }
3127    }
3128
3129    Ok(())
3130}
3131
3132/// Move a channel to a new parent.
3133async fn move_channel(
3134    request: proto::MoveChannel,
3135    response: Response<proto::MoveChannel>,
3136    session: MessageContext,
3137) -> Result<()> {
3138    let channel_id = ChannelId::from_proto(request.channel_id);
3139    let to = ChannelId::from_proto(request.to);
3140
3141    let (root_id, channels) = session
3142        .db()
3143        .await
3144        .move_channel(channel_id, to, session.user_id())
3145        .await?;
3146
3147    let connection_pool = session.connection_pool().await;
3148    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3149        let channels = channels
3150            .iter()
3151            .filter_map(|channel| {
3152                if role.can_see_channel(channel.visibility) {
3153                    Some(channel.to_proto())
3154                } else {
3155                    None
3156                }
3157            })
3158            .collect::<Vec<_>>();
3159        if channels.is_empty() {
3160            continue;
3161        }
3162
3163        let update = proto::UpdateChannels {
3164            channels,
3165            ..Default::default()
3166        };
3167
3168        session.peer.send(connection_id, update.clone())?;
3169    }
3170
3171    response.send(Ack {})?;
3172    Ok(())
3173}
3174
3175async fn reorder_channel(
3176    request: proto::ReorderChannel,
3177    response: Response<proto::ReorderChannel>,
3178    session: MessageContext,
3179) -> Result<()> {
3180    let channel_id = ChannelId::from_proto(request.channel_id);
3181    let direction = request.direction();
3182
3183    let updated_channels = session
3184        .db()
3185        .await
3186        .reorder_channel(channel_id, direction, session.user_id())
3187        .await?;
3188
3189    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3190        let connection_pool = session.connection_pool().await;
3191        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3192            let channels = updated_channels
3193                .iter()
3194                .filter_map(|channel| {
3195                    if role.can_see_channel(channel.visibility) {
3196                        Some(channel.to_proto())
3197                    } else {
3198                        None
3199                    }
3200                })
3201                .collect::<Vec<_>>();
3202
3203            if channels.is_empty() {
3204                continue;
3205            }
3206
3207            let update = proto::UpdateChannels {
3208                channels,
3209                ..Default::default()
3210            };
3211
3212            session.peer.send(connection_id, update.clone())?;
3213        }
3214    }
3215
3216    response.send(Ack {})?;
3217    Ok(())
3218}
3219
3220/// Get the list of channel members
3221async fn get_channel_members(
3222    request: proto::GetChannelMembers,
3223    response: Response<proto::GetChannelMembers>,
3224    session: MessageContext,
3225) -> Result<()> {
3226    let db = session.db().await;
3227    let channel_id = ChannelId::from_proto(request.channel_id);
3228    let limit = if request.limit == 0 {
3229        u16::MAX as u64
3230    } else {
3231        request.limit
3232    };
3233    let (members, users) = db
3234        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3235        .await?;
3236    response.send(proto::GetChannelMembersResponse { members, users })?;
3237    Ok(())
3238}
3239
3240/// Accept or decline a channel invitation.
3241async fn respond_to_channel_invite(
3242    request: proto::RespondToChannelInvite,
3243    response: Response<proto::RespondToChannelInvite>,
3244    session: MessageContext,
3245) -> Result<()> {
3246    let db = session.db().await;
3247    let channel_id = ChannelId::from_proto(request.channel_id);
3248    let RespondToChannelInvite {
3249        membership_update,
3250        notifications,
3251    } = db
3252        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3253        .await?;
3254
3255    let mut connection_pool = session.connection_pool().await;
3256    if let Some(membership_update) = membership_update {
3257        notify_membership_updated(
3258            &mut connection_pool,
3259            membership_update,
3260            session.user_id(),
3261            &session.peer,
3262        );
3263    } else {
3264        let update = proto::UpdateChannels {
3265            remove_channel_invitations: vec![channel_id.to_proto()],
3266            ..Default::default()
3267        };
3268
3269        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3270            session.peer.send(connection_id, update.clone())?;
3271        }
3272    };
3273
3274    send_notifications(&connection_pool, &session.peer, notifications);
3275
3276    response.send(proto::Ack {})?;
3277
3278    Ok(())
3279}
3280
3281/// Join the channels' room
3282async fn join_channel(
3283    request: proto::JoinChannel,
3284    response: Response<proto::JoinChannel>,
3285    session: MessageContext,
3286) -> Result<()> {
3287    let channel_id = ChannelId::from_proto(request.channel_id);
3288    join_channel_internal(channel_id, Box::new(response), session).await
3289}
3290
3291trait JoinChannelInternalResponse {
3292    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3293}
3294impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3295    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3296        Response::<proto::JoinChannel>::send(self, result)
3297    }
3298}
3299impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3300    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3301        Response::<proto::JoinRoom>::send(self, result)
3302    }
3303}
3304
/// Join the room belonging to `channel_id` on behalf of the session's user.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, that connection is taken
///    out of its room via `leave_room_for_session` before joining.
/// 2. The user joins the channel's room in the database. When a LiveKit client is configured,
///    connection info is minted: guests get a guest token with `can_publish: false`, every
///    other role gets a room token with `can_publish: true`.
/// 3. The `JoinRoomResponse` is sent. Then any membership change is pushed to the user's
///    connections, and the updated room is broadcast to its participants.
/// 4. Outside the inner scope (so the pool guard taken in step 3 has been released), channel
///    subscribers are told about the new participant list and the user's contacts are refreshed.
///
/// `response` is generic over `JoinChannelInternalResponse`, so both `JoinChannel` and `JoinRoom`
/// requests share this path.
///
/// # Errors
/// Propagates database and response-send failures. The "channel not returned" error is raised
/// only after the response has already been sent to the client.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` acquires `session.db()` itself, so release this
            // handle first and re-acquire it once the stale connection has been removed.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when minting a token fails
        // (the failure is logged by `trace_err` and the `?` short-circuits the closure).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the joining client before broadcasting to anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The pool guard from the block above is dropped here; take a fresh one for the broadcast.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3398
3399/// Start editing the channel notes
3400async fn join_channel_buffer(
3401    request: proto::JoinChannelBuffer,
3402    response: Response<proto::JoinChannelBuffer>,
3403    session: MessageContext,
3404) -> Result<()> {
3405    let db = session.db().await;
3406    let channel_id = ChannelId::from_proto(request.channel_id);
3407
3408    let open_response = db
3409        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3410        .await?;
3411
3412    let collaborators = open_response.collaborators.clone();
3413    response.send(open_response)?;
3414
3415    let update = UpdateChannelBufferCollaborators {
3416        channel_id: channel_id.to_proto(),
3417        collaborators: collaborators.clone(),
3418    };
3419    channel_buffer_updated(
3420        session.connection_id,
3421        collaborators
3422            .iter()
3423            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3424        &update,
3425        &session.peer,
3426    );
3427
3428    Ok(())
3429}
3430
3431/// Edit the channel notes
3432async fn update_channel_buffer(
3433    request: proto::UpdateChannelBuffer,
3434    session: MessageContext,
3435) -> Result<()> {
3436    let db = session.db().await;
3437    let channel_id = ChannelId::from_proto(request.channel_id);
3438
3439    let (collaborators, epoch, version) = db
3440        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3441        .await?;
3442
3443    channel_buffer_updated(
3444        session.connection_id,
3445        collaborators.clone(),
3446        &proto::UpdateChannelBuffer {
3447            channel_id: channel_id.to_proto(),
3448            operations: request.operations,
3449        },
3450        &session.peer,
3451    );
3452
3453    let pool = &*session.connection_pool().await;
3454
3455    let non_collaborators =
3456        pool.channel_connection_ids(channel_id)
3457            .filter_map(|(connection_id, _)| {
3458                if collaborators.contains(&connection_id) {
3459                    None
3460                } else {
3461                    Some(connection_id)
3462                }
3463            });
3464
3465    broadcast(None, non_collaborators, |peer_id| {
3466        session.peer.send(
3467            peer_id,
3468            proto::UpdateChannels {
3469                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3470                    channel_id: channel_id.to_proto(),
3471                    epoch: epoch as u64,
3472                    version: version.clone(),
3473                }],
3474                ..Default::default()
3475            },
3476        )
3477    });
3478
3479    Ok(())
3480}
3481
3482/// Rejoin the channel notes after a connection blip
3483async fn rejoin_channel_buffers(
3484    request: proto::RejoinChannelBuffers,
3485    response: Response<proto::RejoinChannelBuffers>,
3486    session: MessageContext,
3487) -> Result<()> {
3488    let db = session.db().await;
3489    let buffers = db
3490        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3491        .await?;
3492
3493    for rejoined_buffer in &buffers {
3494        let collaborators_to_notify = rejoined_buffer
3495            .buffer
3496            .collaborators
3497            .iter()
3498            .filter_map(|c| Some(c.peer_id?.into()));
3499        channel_buffer_updated(
3500            session.connection_id,
3501            collaborators_to_notify,
3502            &proto::UpdateChannelBufferCollaborators {
3503                channel_id: rejoined_buffer.buffer.channel_id,
3504                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3505            },
3506            &session.peer,
3507        );
3508    }
3509
3510    response.send(proto::RejoinChannelBuffersResponse {
3511        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3512    })?;
3513
3514    Ok(())
3515}
3516
3517/// Stop editing the channel notes
3518async fn leave_channel_buffer(
3519    request: proto::LeaveChannelBuffer,
3520    response: Response<proto::LeaveChannelBuffer>,
3521    session: MessageContext,
3522) -> Result<()> {
3523    let db = session.db().await;
3524    let channel_id = ChannelId::from_proto(request.channel_id);
3525
3526    let left_buffer = db
3527        .leave_channel_buffer(channel_id, session.connection_id)
3528        .await?;
3529
3530    response.send(Ack {})?;
3531
3532    channel_buffer_updated(
3533        session.connection_id,
3534        left_buffer.connections,
3535        &proto::UpdateChannelBufferCollaborators {
3536            channel_id: channel_id.to_proto(),
3537            collaborators: left_buffer.collaborators,
3538        },
3539        &session.peer,
3540    );
3541
3542    Ok(())
3543}
3544
3545fn channel_buffer_updated<T: EnvelopedMessage>(
3546    sender_id: ConnectionId,
3547    collaborators: impl IntoIterator<Item = ConnectionId>,
3548    message: &T,
3549    peer: &Peer,
3550) {
3551    broadcast(Some(sender_id), collaborators, |peer_id| {
3552        peer.send(peer_id, message.clone())
3553    });
3554}
3555
3556fn send_notifications(
3557    connection_pool: &ConnectionPool,
3558    peer: &Peer,
3559    notifications: db::NotificationBatch,
3560) {
3561    for (user_id, notification) in notifications {
3562        for connection_id in connection_pool.user_connection_ids(user_id) {
3563            if let Err(error) = peer.send(
3564                connection_id,
3565                proto::AddNotification {
3566                    notification: Some(notification.clone()),
3567                },
3568            ) {
3569                tracing::error!(
3570                    "failed to send notification to {:?} {}",
3571                    connection_id,
3572                    error
3573                );
3574            }
3575        }
3576    }
3577}
3578
3579/// Send a message to the channel
3580async fn send_channel_message(
3581    _request: proto::SendChannelMessage,
3582    _response: Response<proto::SendChannelMessage>,
3583    _session: MessageContext,
3584) -> Result<()> {
3585    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3586}
3587
3588/// Delete a channel message
3589async fn remove_channel_message(
3590    _request: proto::RemoveChannelMessage,
3591    _response: Response<proto::RemoveChannelMessage>,
3592    _session: MessageContext,
3593) -> Result<()> {
3594    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3595}
3596
3597async fn update_channel_message(
3598    _request: proto::UpdateChannelMessage,
3599    _response: Response<proto::UpdateChannelMessage>,
3600    _session: MessageContext,
3601) -> Result<()> {
3602    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3603}
3604
3605/// Mark a channel message as read
3606async fn acknowledge_channel_message(
3607    _request: proto::AckChannelMessage,
3608    _session: MessageContext,
3609) -> Result<()> {
3610    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3611}
3612
3613/// Mark a buffer version as synced
3614async fn acknowledge_buffer_version(
3615    request: proto::AckBufferOperation,
3616    session: MessageContext,
3617) -> Result<()> {
3618    let buffer_id = BufferId::from_proto(request.buffer_id);
3619    session
3620        .db()
3621        .await
3622        .observe_buffer_version(
3623            buffer_id,
3624            session.user_id(),
3625            request.epoch as i32,
3626            &request.version,
3627        )
3628        .await?;
3629    Ok(())
3630}
3631
3632/// Get a Supermaven API key for the user
3633async fn get_supermaven_api_key(
3634    _request: proto::GetSupermavenApiKey,
3635    response: Response<proto::GetSupermavenApiKey>,
3636    session: MessageContext,
3637) -> Result<()> {
3638    let user_id: String = session.user_id().to_string();
3639    if !session.is_staff() {
3640        return Err(anyhow!("supermaven not enabled for this account"))?;
3641    }
3642
3643    let email = session.email().context("user must have an email")?;
3644
3645    let supermaven_admin_api = session
3646        .supermaven_client
3647        .as_ref()
3648        .context("supermaven not configured")?;
3649
3650    let result = supermaven_admin_api
3651        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3652        .await?;
3653
3654    response.send(proto::GetSupermavenApiKeyResponse {
3655        api_key: result.api_key,
3656    })?;
3657
3658    Ok(())
3659}
3660
3661/// Start receiving chat updates for a channel
3662async fn join_channel_chat(
3663    _request: proto::JoinChannelChat,
3664    _response: Response<proto::JoinChannelChat>,
3665    _session: MessageContext,
3666) -> Result<()> {
3667    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3668}
3669
3670/// Stop receiving chat updates for a channel
3671async fn leave_channel_chat(
3672    _request: proto::LeaveChannelChat,
3673    _session: MessageContext,
3674) -> Result<()> {
3675    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3676}
3677
3678/// Retrieve the chat history for a channel
3679async fn get_channel_messages(
3680    _request: proto::GetChannelMessages,
3681    _response: Response<proto::GetChannelMessages>,
3682    _session: MessageContext,
3683) -> Result<()> {
3684    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3685}
3686
3687/// Retrieve specific chat messages
3688async fn get_channel_messages_by_id(
3689    _request: proto::GetChannelMessagesById,
3690    _response: Response<proto::GetChannelMessagesById>,
3691    _session: MessageContext,
3692) -> Result<()> {
3693    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3694}
3695
3696/// Retrieve the current users notifications
3697async fn get_notifications(
3698    request: proto::GetNotifications,
3699    response: Response<proto::GetNotifications>,
3700    session: MessageContext,
3701) -> Result<()> {
3702    let notifications = session
3703        .db()
3704        .await
3705        .get_notifications(
3706            session.user_id(),
3707            NOTIFICATION_COUNT_PER_PAGE,
3708            request.before_id.map(db::NotificationId::from_proto),
3709        )
3710        .await?;
3711    response.send(proto::GetNotificationsResponse {
3712        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3713        notifications,
3714    })?;
3715    Ok(())
3716}
3717
3718/// Mark notifications as read
3719async fn mark_notification_as_read(
3720    request: proto::MarkNotificationRead,
3721    response: Response<proto::MarkNotificationRead>,
3722    session: MessageContext,
3723) -> Result<()> {
3724    let database = &session.db().await;
3725    let notifications = database
3726        .mark_notification_as_read_by_id(
3727            session.user_id(),
3728            NotificationId::from_proto(request.notification_id),
3729        )
3730        .await?;
3731    send_notifications(
3732        &*session.connection_pool().await,
3733        &session.peer,
3734        notifications,
3735    );
3736    response.send(proto::Ack {})?;
3737    Ok(())
3738}
3739
3740fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3741    let message = match message {
3742        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3743        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3744        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3745        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3746        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3747            code: frame.code.into(),
3748            reason: frame.reason.as_str().to_owned().into(),
3749        })),
3750        // We should never receive a frame while reading the message, according
3751        // to the `tungstenite` maintainers:
3752        //
3753        // > It cannot occur when you read messages from the WebSocket, but it
3754        // > can be used when you want to send the raw frames (e.g. you want to
3755        // > send the frames to the WebSocket without composing the full message first).
3756        // >
3757        // > — https://github.com/snapview/tungstenite-rs/issues/268
3758        TungsteniteMessage::Frame(_) => {
3759            bail!("received an unexpected frame while reading the message")
3760        }
3761    };
3762
3763    Ok(message)
3764}
3765
3766fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3767    match message {
3768        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3769        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3770        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3771        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3772        AxumMessage::Close(frame) => {
3773            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3774                code: frame.code.into(),
3775                reason: frame.reason.as_ref().into(),
3776            }))
3777        }
3778    }
3779}
3780
3781fn notify_membership_updated(
3782    connection_pool: &mut ConnectionPool,
3783    result: MembershipUpdated,
3784    user_id: UserId,
3785    peer: &Peer,
3786) {
3787    for membership in &result.new_channels.channel_memberships {
3788        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3789    }
3790    for channel_id in &result.removed_channels {
3791        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3792    }
3793
3794    let user_channels_update = proto::UpdateUserChannels {
3795        channel_memberships: result
3796            .new_channels
3797            .channel_memberships
3798            .iter()
3799            .map(|cm| proto::ChannelMembership {
3800                channel_id: cm.channel_id.to_proto(),
3801                role: cm.role.into(),
3802            })
3803            .collect(),
3804        ..Default::default()
3805    };
3806
3807    let mut update = build_channels_update(result.new_channels);
3808    update.delete_channels = result
3809        .removed_channels
3810        .into_iter()
3811        .map(|id| id.to_proto())
3812        .collect();
3813    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3814
3815    for connection_id in connection_pool.user_connection_ids(user_id) {
3816        peer.send(connection_id, user_channels_update.clone())
3817            .trace_err();
3818        peer.send(connection_id, update.clone()).trace_err();
3819    }
3820}
3821
3822fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3823    proto::UpdateUserChannels {
3824        channel_memberships: channels
3825            .channel_memberships
3826            .iter()
3827            .map(|m| proto::ChannelMembership {
3828                channel_id: m.channel_id.to_proto(),
3829                role: m.role.into(),
3830            })
3831            .collect(),
3832        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3833    }
3834}
3835
3836fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3837    let mut update = proto::UpdateChannels::default();
3838
3839    for channel in channels.channels {
3840        update.channels.push(channel.to_proto());
3841    }
3842
3843    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3844
3845    for (channel_id, participants) in channels.channel_participants {
3846        update
3847            .channel_participants
3848            .push(proto::ChannelParticipants {
3849                channel_id: channel_id.to_proto(),
3850                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3851            });
3852    }
3853
3854    for channel in channels.invited_channels {
3855        update.channel_invitations.push(channel.to_proto());
3856    }
3857
3858    update
3859}
3860
3861fn build_initial_contacts_update(
3862    contacts: Vec<db::Contact>,
3863    pool: &ConnectionPool,
3864) -> proto::UpdateContacts {
3865    let mut update = proto::UpdateContacts::default();
3866
3867    for contact in contacts {
3868        match contact {
3869            db::Contact::Accepted { user_id, busy } => {
3870                update.contacts.push(contact_for_user(user_id, busy, pool));
3871            }
3872            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3873            db::Contact::Incoming { user_id } => {
3874                update
3875                    .incoming_requests
3876                    .push(proto::IncomingContactRequest {
3877                        requester_id: user_id.to_proto(),
3878                    })
3879            }
3880        }
3881    }
3882
3883    update
3884}
3885
3886fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3887    proto::Contact {
3888        user_id: user_id.to_proto(),
3889        online: pool.is_user_online(user_id),
3890        busy,
3891    }
3892}
3893
3894fn room_updated(room: &proto::Room, peer: &Peer) {
3895    broadcast(
3896        None,
3897        room.participants
3898            .iter()
3899            .filter_map(|participant| Some(participant.peer_id?.into())),
3900        |peer_id| {
3901            peer.send(
3902                peer_id,
3903                proto::RoomUpdated {
3904                    room: Some(room.clone()),
3905                },
3906            )
3907        },
3908    );
3909}
3910
3911fn channel_updated(
3912    channel: &db::channel::Model,
3913    room: &proto::Room,
3914    peer: &Peer,
3915    pool: &ConnectionPool,
3916) {
3917    let participants = room
3918        .participants
3919        .iter()
3920        .map(|p| p.user_id)
3921        .collect::<Vec<_>>();
3922
3923    broadcast(
3924        None,
3925        pool.channel_connection_ids(channel.root_id())
3926            .filter_map(|(channel_id, role)| {
3927                role.can_see_channel(channel.visibility)
3928                    .then_some(channel_id)
3929            }),
3930        |peer_id| {
3931            peer.send(
3932                peer_id,
3933                proto::UpdateChannels {
3934                    channel_participants: vec![proto::ChannelParticipants {
3935                        channel_id: channel.id.to_proto(),
3936                        participant_user_ids: participants.clone(),
3937                    }],
3938                    ..Default::default()
3939                },
3940            )
3941        },
3942    );
3943}
3944
3945async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3946    let db = session.db().await;
3947
3948    let contacts = db.get_contacts(user_id).await?;
3949    let busy = db.is_user_busy(user_id).await?;
3950
3951    let pool = session.connection_pool().await;
3952    let updated_contact = contact_for_user(user_id, busy, &pool);
3953    for contact in contacts {
3954        if let db::Contact::Accepted {
3955            user_id: contact_user_id,
3956            ..
3957        } = contact
3958        {
3959            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3960                session
3961                    .peer
3962                    .send(
3963                        contact_conn_id,
3964                        proto::UpdateContacts {
3965                            contacts: vec![updated_contact.clone()],
3966                            remove_contacts: Default::default(),
3967                            incoming_requests: Default::default(),
3968                            remove_incoming_requests: Default::default(),
3969                            outgoing_requests: Default::default(),
3970                            remove_outgoing_requests: Default::default(),
3971                        },
3972                    )
3973                    .trace_err();
3974            }
3975        }
3976    }
3977    Ok(())
3978}
3979
/// Remove `connection_id` from whatever room it is in and notify everyone affected.
///
/// Returns `Ok(())` without doing anything else when `leave_room` reports that the connection
/// was not in a room. Otherwise:
/// - collaborators of each project the connection left are notified (`project_left`);
/// - the updated room is broadcast to its participants;
/// - for channel rooms, channel subscribers get the new participant list;
/// - users whose calls the database canceled receive `CallCanceled`;
/// - contacts of the session user and of those users are refreshed;
/// - the user is removed from the LiveKit room, and the room is deleted when the database
///   marks it as deleted.
///
/// NOTE(review): the contact refresh and LiveKit removal use `session.user_id()`, so this
/// assumes `connection_id` belongs to the session's own user — confirm at call sites.
///
/// # Errors
/// Propagates errors from `leave_room` and `update_user_contacts`. Send and LiveKit failures
/// are only logged.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out below, so the `room` that gets broadcast carries a
        // defaulted `livekit_room` value.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the pool guard is released before `update_user_contacts`, which acquires
    // the connection pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4053
4054async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4055    let left_channel_buffers = session
4056        .db()
4057        .await
4058        .leave_channel_buffers(session.connection_id)
4059        .await?;
4060
4061    for left_buffer in left_channel_buffers {
4062        channel_buffer_updated(
4063            session.connection_id,
4064            left_buffer.connections,
4065            &proto::UpdateChannelBufferCollaborators {
4066                channel_id: left_buffer.channel_id.to_proto(),
4067                collaborators: left_buffer.collaborators,
4068            },
4069            &session.peer,
4070        );
4071    }
4072
4073    Ok(())
4074}
4075
4076fn project_left(project: &db::LeftProject, session: &Session) {
4077    for connection_id in &project.connection_ids {
4078        if project.should_unshare {
4079            session
4080                .peer
4081                .send(
4082                    *connection_id,
4083                    proto::UnshareProject {
4084                        project_id: project.id.to_proto(),
4085                    },
4086                )
4087                .trace_err();
4088        } else {
4089            session
4090                .peer
4091                .send(
4092                    *connection_id,
4093                    proto::RemoveProjectCollaborator {
4094                        project_id: project.id.to_proto(),
4095                        peer_id: Some(session.connection_id.into()),
4096                    },
4097                )
4098                .trace_err();
4099        }
4100    }
4101}
4102
/// Best-effort error handling: turn a fallible value into an `Option`, reporting the failure
/// instead of propagating it.
pub trait ResultExt {
    /// The success type carried by the value.
    type Ok;

    /// Return the success value, or report the error and return `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
4108
4109impl<T, E> ResultExt for Result<T, E>
4110where
4111    E: std::fmt::Debug,
4112{
4113    type Ok = T;
4114
4115    #[track_caller]
4116    fn trace_err(self) -> Option<T> {
4117        match self {
4118            Ok(value) => Some(value),
4119            Err(error) => {
4120                tracing::error!("{:?}", error);
4121                None
4122            }
4123        }
4124    }
4125}