rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::{
   5    AppState, Error, Result, auth,
   6    db::{
   7        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   8        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
   9        NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
  10        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  11    },
  12    executor::Executor,
  13};
  14use anyhow::{Context as _, anyhow, bail};
  15use async_tungstenite::tungstenite::{
  16    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  17};
  18use axum::headers::UserAgent;
  19use axum::{
  20    Extension, Router, TypedHeader,
  21    body::Body,
  22    extract::{
  23        ConnectInfo, WebSocketUpgrade,
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use futures::TryFutureExt as _;
  36use reqwest_client::ReqwestClient;
  37use rpc::proto::{MultiLspQuery, split_repository_update};
  38use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  39use tracing::Span;
  40
  41use futures::{
  42    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  43    stream::FuturesUnordered,
  44};
  45use prometheus::{IntGauge, register_int_gauge};
  46use rpc::{
  47    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  50        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52};
  53use semantic_version::SemanticVersion;
  54use serde::{Serialize, Serializer};
  55use std::{
  56    any::TypeId,
  57    future::Future,
  58    marker::PhantomData,
  59    mem,
  60    net::SocketAddr,
  61    ops::{Deref, DerefMut},
  62    rc::Rc,
  63    sync::{
  64        Arc, OnceLock,
  65        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  66    },
  67    time::{Duration, Instant},
  68};
  69use time::OffsetDateTime;
  70use tokio::sync::{Semaphore, watch};
  71use tower::ServiceBuilder;
  72use tracing::{
  73    Instrument,
  74    field::{self},
  75    info_span, instrument,
  76};
  77
  78pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  79
  80// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  81pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  82
  83const MESSAGE_COUNT_PER_PAGE: usize = 100;
  84const MAX_MESSAGE_LEN: usize = 1024;
  85const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  86const MAX_CONCURRENT_CONNECTIONS: usize = 512;
  87
  88static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
  89
  90const TOTAL_DURATION_MS: &str = "total_duration_ms";
  91const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
  92const QUEUE_DURATION_MS: &str = "queue_duration_ms";
  93const HOST_WAITING_MS: &str = "host_waiting_ms";
  94
  95type MessageHandler =
  96    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
  97
  98pub struct ConnectionGuard;
  99
 100impl ConnectionGuard {
 101    pub fn try_acquire() -> Result<Self, ()> {
 102        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
 103        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 104            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 105            tracing::error!(
 106                "too many concurrent connections: {}",
 107                current_connections + 1
 108            );
 109            return Err(());
 110        }
 111        Ok(ConnectionGuard)
 112    }
 113}
 114
 115impl Drop for ConnectionGuard {
 116    fn drop(&mut self) {
 117        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 118    }
 119}
 120
 121struct Response<R> {
 122    peer: Arc<Peer>,
 123    receipt: Receipt<R>,
 124    responded: Arc<AtomicBool>,
 125}
 126
 127impl<R: RequestMessage> Response<R> {
 128    fn send(self, payload: R::Response) -> Result<()> {
 129        self.responded.store(true, SeqCst);
 130        self.peer.respond(self.receipt, payload)?;
 131        Ok(())
 132    }
 133}
 134
 135#[derive(Clone, Debug)]
 136pub enum Principal {
 137    User(User),
 138    Impersonated { user: User, admin: User },
 139}
 140
 141impl Principal {
 142    fn update_span(&self, span: &tracing::Span) {
 143        match &self {
 144            Principal::User(user) => {
 145                span.record("user_id", user.id.0);
 146                span.record("login", &user.github_login);
 147            }
 148            Principal::Impersonated { user, admin } => {
 149                span.record("user_id", user.id.0);
 150                span.record("login", &user.github_login);
 151                span.record("impersonator", &admin.github_login);
 152            }
 153        }
 154    }
 155}
 156
 157#[derive(Clone)]
 158struct MessageContext {
 159    session: Session,
 160    span: tracing::Span,
 161}
 162
 163impl Deref for MessageContext {
 164    type Target = Session;
 165
 166    fn deref(&self) -> &Self::Target {
 167        &self.session
 168    }
 169}
 170
 171impl MessageContext {
 172    pub fn forward_request<T: RequestMessage>(
 173        &self,
 174        receiver_id: ConnectionId,
 175        request: T,
 176    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 177        let request_start_time = Instant::now();
 178        let span = self.span.clone();
 179        tracing::info!("start forwarding request");
 180        self.peer
 181            .forward_request(self.connection_id, receiver_id, request)
 182            .inspect(move |_| {
 183                span.record(
 184                    HOST_WAITING_MS,
 185                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 186                );
 187            })
 188            .inspect_err(|_| tracing::error!("error forwarding request"))
 189            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 190    }
 191}
 192
 193#[derive(Clone)]
 194struct Session {
 195    principal: Principal,
 196    connection_id: ConnectionId,
 197    db: Arc<tokio::sync::Mutex<DbHandle>>,
 198    peer: Arc<Peer>,
 199    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 200    app_state: Arc<AppState>,
 201    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 202    /// The GeoIP country code for the user.
 203    #[allow(unused)]
 204    geoip_country_code: Option<String>,
 205    #[allow(unused)]
 206    system_id: Option<String>,
 207    _executor: Executor,
 208}
 209
 210impl Session {
 211    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 212        #[cfg(test)]
 213        tokio::task::yield_now().await;
 214        let guard = self.db.lock().await;
 215        #[cfg(test)]
 216        tokio::task::yield_now().await;
 217        guard
 218    }
 219
 220    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 221        #[cfg(test)]
 222        tokio::task::yield_now().await;
 223        let guard = self.connection_pool.lock();
 224        ConnectionPoolGuard {
 225            guard,
 226            _not_send: PhantomData,
 227        }
 228    }
 229
 230    fn is_staff(&self) -> bool {
 231        match &self.principal {
 232            Principal::User(user) => user.admin,
 233            Principal::Impersonated { .. } => true,
 234        }
 235    }
 236
 237    fn user_id(&self) -> UserId {
 238        match &self.principal {
 239            Principal::User(user) => user.id,
 240            Principal::Impersonated { user, .. } => user.id,
 241        }
 242    }
 243
 244    pub fn email(&self) -> Option<String> {
 245        match &self.principal {
 246            Principal::User(user) => user.email_address.clone(),
 247            Principal::Impersonated { user, .. } => user.email_address.clone(),
 248        }
 249    }
 250}
 251
 252impl Debug for Session {
 253    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 254        let mut result = f.debug_struct("Session");
 255        match &self.principal {
 256            Principal::User(user) => {
 257                result.field("user", &user.github_login);
 258            }
 259            Principal::Impersonated { user, admin } => {
 260                result.field("user", &user.github_login);
 261                result.field("impersonator", &admin.github_login);
 262            }
 263        }
 264        result.field("connection_id", &self.connection_id).finish()
 265    }
 266}
 267
 268struct DbHandle(Arc<Database>);
 269
 270impl Deref for DbHandle {
 271    type Target = Database;
 272
 273    fn deref(&self) -> &Self::Target {
 274        self.0.as_ref()
 275    }
 276}
 277
 278pub struct Server {
 279    id: parking_lot::Mutex<ServerId>,
 280    peer: Arc<Peer>,
 281    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 282    app_state: Arc<AppState>,
 283    handlers: HashMap<TypeId, MessageHandler>,
 284    teardown: watch::Sender<bool>,
 285}
 286
 287pub(crate) struct ConnectionPoolGuard<'a> {
 288    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 289    _not_send: PhantomData<Rc<()>>,
 290}
 291
 292#[derive(Serialize)]
 293pub struct ServerSnapshot<'a> {
 294    peer: &'a Peer,
 295    #[serde(serialize_with = "serialize_deref")]
 296    connection_pool: ConnectionPoolGuard<'a>,
 297}
 298
 299pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 300where
 301    S: Serializer,
 302    T: Deref<Target = U>,
 303    U: Serialize,
 304{
 305    Serialize::serialize(value.deref(), serializer)
 306}
 307
 308impl Server {
 309    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 310        let mut server = Self {
 311            id: parking_lot::Mutex::new(id),
 312            peer: Peer::new(id.0 as u32),
 313            app_state,
 314            connection_pool: Default::default(),
 315            handlers: Default::default(),
 316            teardown: watch::channel(false).0,
 317        };
 318
 319        server
 320            .add_request_handler(ping)
 321            .add_request_handler(create_room)
 322            .add_request_handler(join_room)
 323            .add_request_handler(rejoin_room)
 324            .add_request_handler(leave_room)
 325            .add_request_handler(set_room_participant_role)
 326            .add_request_handler(call)
 327            .add_request_handler(cancel_call)
 328            .add_message_handler(decline_call)
 329            .add_request_handler(update_participant_location)
 330            .add_request_handler(share_project)
 331            .add_message_handler(unshare_project)
 332            .add_request_handler(join_project)
 333            .add_message_handler(leave_project)
 334            .add_request_handler(update_project)
 335            .add_request_handler(update_worktree)
 336            .add_request_handler(update_repository)
 337            .add_request_handler(remove_repository)
 338            .add_message_handler(start_language_server)
 339            .add_message_handler(update_language_server)
 340            .add_message_handler(update_diagnostic_summary)
 341            .add_message_handler(update_worktree_settings)
 342            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 343            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 344            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 345            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 346            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 347            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 348            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 349            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 350            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 351            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 352            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 353            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 354            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 355            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 356            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 357            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 358            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 359            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 360            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 361            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 362            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 363            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 364            .add_request_handler(
 365                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 366            )
 367            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 368            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 369            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 370            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 371            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
 372            .add_request_handler(
 373                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 374            )
 375            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 376            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 377            .add_request_handler(
 378                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 379            )
 380            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 381            .add_request_handler(
 382                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 383            )
 384            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 385            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 386            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 387            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 388            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 389            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 390            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 391            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 392            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 393            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 394            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 395            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 396            .add_request_handler(
 397                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 398            )
 399            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 400            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 401            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 402            .add_request_handler(multi_lsp_query)
 403            .add_request_handler(lsp_query)
 404            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
 405            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 406            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 407            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 408            .add_message_handler(create_buffer_for_peer)
 409            .add_request_handler(update_buffer)
 410            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 411            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 412            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 413            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 414            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 415            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 416            .add_message_handler(
 417                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 418            )
 419            .add_request_handler(get_users)
 420            .add_request_handler(fuzzy_search_users)
 421            .add_request_handler(request_contact)
 422            .add_request_handler(remove_contact)
 423            .add_request_handler(respond_to_contact_request)
 424            .add_message_handler(subscribe_to_channels)
 425            .add_request_handler(create_channel)
 426            .add_request_handler(delete_channel)
 427            .add_request_handler(invite_channel_member)
 428            .add_request_handler(remove_channel_member)
 429            .add_request_handler(set_channel_member_role)
 430            .add_request_handler(set_channel_visibility)
 431            .add_request_handler(rename_channel)
 432            .add_request_handler(join_channel_buffer)
 433            .add_request_handler(leave_channel_buffer)
 434            .add_message_handler(update_channel_buffer)
 435            .add_request_handler(rejoin_channel_buffers)
 436            .add_request_handler(get_channel_members)
 437            .add_request_handler(respond_to_channel_invite)
 438            .add_request_handler(join_channel)
 439            .add_request_handler(join_channel_chat)
 440            .add_message_handler(leave_channel_chat)
 441            .add_request_handler(send_channel_message)
 442            .add_request_handler(remove_channel_message)
 443            .add_request_handler(update_channel_message)
 444            .add_request_handler(get_channel_messages)
 445            .add_request_handler(get_channel_messages_by_id)
 446            .add_request_handler(get_notifications)
 447            .add_request_handler(mark_notification_as_read)
 448            .add_request_handler(move_channel)
 449            .add_request_handler(reorder_channel)
 450            .add_request_handler(follow)
 451            .add_message_handler(unfollow)
 452            .add_message_handler(update_followers)
 453            .add_message_handler(acknowledge_channel_message)
 454            .add_message_handler(acknowledge_buffer_version)
 455            .add_request_handler(get_supermaven_api_key)
 456            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 457            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 458            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 459            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 460            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 461            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 462            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 463            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 464            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 465            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 466            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 467            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 468            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 469            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 470            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 471            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 472            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 473            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 474            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 475            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 476            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 477            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 478            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 479            .add_message_handler(update_context)
 480            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
 481            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>);
 482
 483        Arc::new(server)
 484    }
 485
 486    pub async fn start(&self) -> Result<()> {
 487        let server_id = *self.id.lock();
 488        let app_state = self.app_state.clone();
 489        let peer = self.peer.clone();
 490        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 491        let pool = self.connection_pool.clone();
 492        let livekit_client = self.app_state.livekit_client.clone();
 493
 494        let span = info_span!("start server");
 495        self.app_state.executor.spawn_detached(
 496            async move {
 497                tracing::info!("waiting for cleanup timeout");
 498                timeout.await;
 499                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 500
 501                app_state
 502                    .db
 503                    .delete_stale_channel_chat_participants(
 504                        &app_state.config.zed_environment,
 505                        server_id,
 506                    )
 507                    .await
 508                    .trace_err();
 509
 510                if let Some((room_ids, channel_ids)) = app_state
 511                    .db
 512                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 513                    .await
 514                    .trace_err()
 515                {
 516                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 517                    tracing::info!(
 518                        stale_channel_buffer_count = channel_ids.len(),
 519                        "retrieved stale channel buffers"
 520                    );
 521
 522                    for channel_id in channel_ids {
 523                        if let Some(refreshed_channel_buffer) = app_state
 524                            .db
 525                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 526                            .await
 527                            .trace_err()
 528                        {
 529                            for connection_id in refreshed_channel_buffer.connection_ids {
 530                                peer.send(
 531                                    connection_id,
 532                                    proto::UpdateChannelBufferCollaborators {
 533                                        channel_id: channel_id.to_proto(),
 534                                        collaborators: refreshed_channel_buffer
 535                                            .collaborators
 536                                            .clone(),
 537                                    },
 538                                )
 539                                .trace_err();
 540                            }
 541                        }
 542                    }
 543
 544                    for room_id in room_ids {
 545                        let mut contacts_to_update = HashSet::default();
 546                        let mut canceled_calls_to_user_ids = Vec::new();
 547                        let mut livekit_room = String::new();
 548                        let mut delete_livekit_room = false;
 549
 550                        if let Some(mut refreshed_room) = app_state
 551                            .db
 552                            .clear_stale_room_participants(room_id, server_id)
 553                            .await
 554                            .trace_err()
 555                        {
 556                            tracing::info!(
 557                                room_id = room_id.0,
 558                                new_participant_count = refreshed_room.room.participants.len(),
 559                                "refreshed room"
 560                            );
 561                            room_updated(&refreshed_room.room, &peer);
 562                            if let Some(channel) = refreshed_room.channel.as_ref() {
 563                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 564                            }
 565                            contacts_to_update
 566                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 567                            contacts_to_update
 568                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 569                            canceled_calls_to_user_ids =
 570                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 571                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
 572                            delete_livekit_room = refreshed_room.room.participants.is_empty();
 573                        }
 574
 575                        {
 576                            let pool = pool.lock();
 577                            for canceled_user_id in canceled_calls_to_user_ids {
 578                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 579                                    peer.send(
 580                                        connection_id,
 581                                        proto::CallCanceled {
 582                                            room_id: room_id.to_proto(),
 583                                        },
 584                                    )
 585                                    .trace_err();
 586                                }
 587                            }
 588                        }
 589
 590                        for user_id in contacts_to_update {
 591                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 592                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 593                            if let Some((busy, contacts)) = busy.zip(contacts) {
 594                                let pool = pool.lock();
 595                                let updated_contact = contact_for_user(user_id, busy, &pool);
 596                                for contact in contacts {
 597                                    if let db::Contact::Accepted {
 598                                        user_id: contact_user_id,
 599                                        ..
 600                                    } = contact
 601                                    {
 602                                        for contact_conn_id in
 603                                            pool.user_connection_ids(contact_user_id)
 604                                        {
 605                                            peer.send(
 606                                                contact_conn_id,
 607                                                proto::UpdateContacts {
 608                                                    contacts: vec![updated_contact.clone()],
 609                                                    remove_contacts: Default::default(),
 610                                                    incoming_requests: Default::default(),
 611                                                    remove_incoming_requests: Default::default(),
 612                                                    outgoing_requests: Default::default(),
 613                                                    remove_outgoing_requests: Default::default(),
 614                                                },
 615                                            )
 616                                            .trace_err();
 617                                        }
 618                                    }
 619                                }
 620                            }
 621                        }
 622
 623                        if let Some(live_kit) = livekit_client.as_ref()
 624                            && delete_livekit_room
 625                        {
 626                            live_kit.delete_room(livekit_room).await.trace_err();
 627                        }
 628                    }
 629                }
 630
 631                app_state
 632                    .db
 633                    .delete_stale_channel_chat_participants(
 634                        &app_state.config.zed_environment,
 635                        server_id,
 636                    )
 637                    .await
 638                    .trace_err();
 639
 640                app_state
 641                    .db
 642                    .clear_old_worktree_entries(server_id)
 643                    .await
 644                    .trace_err();
 645
 646                app_state
 647                    .db
 648                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 649                    .await
 650                    .trace_err();
 651            }
 652            .instrument(span),
 653        );
 654        Ok(())
 655    }
 656
 657    pub fn teardown(&self) {
 658        self.peer.teardown();
 659        self.connection_pool.lock().reset();
 660        let _ = self.teardown.send(true);
 661    }
 662
 663    #[cfg(test)]
 664    pub fn reset(&self, id: ServerId) {
 665        self.teardown();
 666        *self.id.lock() = id;
 667        self.peer.reset(id.0 as u32);
 668        let _ = self.teardown.send(false);
 669    }
 670
 671    #[cfg(test)]
 672    pub fn id(&self) -> ServerId {
 673        *self.id.lock()
 674    }
 675
 676    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 677    where
 678        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
 679        Fut: 'static + Send + Future<Output = Result<()>>,
 680        M: EnvelopedMessage,
 681    {
 682        let prev_handler = self.handlers.insert(
 683            TypeId::of::<M>(),
 684            Box::new(move |envelope, session, span| {
 685                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 686                let received_at = envelope.received_at;
 687                tracing::info!("message received");
 688                let start_time = Instant::now();
 689                let future = (handler)(
 690                    *envelope,
 691                    MessageContext {
 692                        session,
 693                        span: span.clone(),
 694                    },
 695                );
 696                async move {
 697                    let result = future.await;
 698                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 699                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 700                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 701                    span.record(TOTAL_DURATION_MS, total_duration_ms);
 702                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
 703                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
 704                    match result {
 705                        Err(error) => {
 706                            tracing::error!(?error, "error handling message")
 707                        }
 708                        Ok(()) => tracing::info!("finished handling message"),
 709                    }
 710                }
 711                .boxed()
 712            }),
 713        );
 714        if prev_handler.is_some() {
 715            panic!("registered a handler for the same message twice");
 716        }
 717        self
 718    }
 719
 720    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 721    where
 722        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 723        Fut: 'static + Send + Future<Output = Result<()>>,
 724        M: EnvelopedMessage,
 725    {
 726        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 727        self
 728    }
 729
 730    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 731    where
 732        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 733        Fut: Send + Future<Output = Result<()>>,
 734        M: RequestMessage,
 735    {
 736        let handler = Arc::new(handler);
 737        self.add_handler(move |envelope, session| {
 738            let receipt = envelope.receipt();
 739            let handler = handler.clone();
 740            async move {
 741                let peer = session.peer.clone();
 742                let responded = Arc::new(AtomicBool::default());
 743                let response = Response {
 744                    peer: peer.clone(),
 745                    responded: responded.clone(),
 746                    receipt,
 747                };
 748                match (handler)(envelope.payload, response, session).await {
 749                    Ok(()) => {
 750                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 751                            Ok(())
 752                        } else {
 753                            Err(anyhow!("handler did not send a response"))?
 754                        }
 755                    }
 756                    Err(error) => {
 757                        let proto_err = match &error {
 758                            Error::Internal(err) => err.to_proto(),
 759                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 760                        };
 761                        peer.respond_with_error(receipt, proto_err)?;
 762                        Err(error)
 763                    }
 764                }
 765            }
 766        })
 767    }
 768
    /// Builds the future that serves a single client connection for its whole lifetime.
    ///
    /// Nothing happens until the returned future is polled. When polled it:
    /// 1. returns immediately (logging an error) if teardown has already been signaled;
    /// 2. registers `connection` with the peer and records the assigned id on the span;
    /// 3. builds the per-connection [`Session`], including a Supermaven admin client when
    ///    an admin API key is configured;
    /// 4. sends the initial client update (`send_connection_id`, if given, receives the
    ///    connection id there), then releases `connection_guard`;
    /// 5. dispatches incoming messages to the registered handlers until the incoming
    ///    stream ends or the connection I/O future completes;
    /// 6. calls `connection_lost`.
    ///
    /// If the teardown watch changes while the message loop runs, the future returns
    /// without calling `connection_lost`. Early failures also return without it: creating
    /// the HTTP client, or sending the initial client update.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields declared `Empty` are filled in below, or once the connection id is known.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribe before the future starts so a teardown signaled later is still observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // This is the server's own User-Agent for outgoing HTTP requests; it shadows
            // the client's `user_agent` parameter, which was already recorded on the span.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            // NOTE(review): `send_initial_client_update` adds the connection to the pool
            // partway through. If a later step fails, this returns without calling
            // `connection_lost`. The connection may then stay in the pool and the peer —
            // confirm whether that is intended.
            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The concurrent-connection slot is only needed until the handshake is done.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Every dispatched handler (foreground or background) holds one permit until it
            // finishes. This caps in-flight handlers at MAX_CONCURRENT_HANDLERS.
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // A permit is acquired before the next message is read, so no new message is
                // pulled while all handler slots are busy.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in the order written. The order is: teardown, then
                // connection I/O, then in-flight foreground handlers, then the next message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    // Drives foreground handlers; completed ones simply drop out of the set.
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                // todo(lsp) remove after Zed Stable hits v0.204.x
                                multi_lsp_query_request=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            // The span is entered only while the handler future is constructed.
                            // The rest of the handler runs under `.instrument(span)` below.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // Releasing the permit when the handler finishes frees its slot.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set drops (cancels) any foreground handlers still in flight.
            drop(foreground_message_handlers);
            // Permits still held here presumably belong to detached background handlers.
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
 962
 963    async fn send_initial_client_update(
 964        &self,
 965        connection_id: ConnectionId,
 966        zed_version: ZedVersion,
 967        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 968        session: &Session,
 969    ) -> Result<()> {
 970        self.peer.send(
 971            connection_id,
 972            proto::Hello {
 973                peer_id: Some(connection_id.into()),
 974            },
 975        )?;
 976        tracing::info!("sent hello message");
 977        if let Some(send_connection_id) = send_connection_id.take() {
 978            let _ = send_connection_id.send(connection_id);
 979        }
 980
 981        match &session.principal {
 982            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 983                if !user.connected_once {
 984                    self.peer.send(connection_id, proto::ShowContacts {})?;
 985                    self.app_state
 986                        .db
 987                        .set_user_connected_once(user.id, true)
 988                        .await?;
 989                }
 990
 991                let contacts = self.app_state.db.get_contacts(user.id).await?;
 992
 993                {
 994                    let mut pool = self.connection_pool.lock();
 995                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 996                    self.peer.send(
 997                        connection_id,
 998                        build_initial_contacts_update(contacts, &pool),
 999                    )?;
1000                }
1001
1002                if should_auto_subscribe_to_channels(zed_version) {
1003                    subscribe_user_to_channels(user.id, session).await?;
1004                }
1005
1006                if let Some(incoming_call) =
1007                    self.app_state.db.incoming_call_for_user(user.id).await?
1008                {
1009                    self.peer.send(connection_id, incoming_call)?;
1010                }
1011
1012                update_user_contacts(user.id, session).await?;
1013            }
1014        }
1015
1016        Ok(())
1017    }
1018
1019    pub async fn invite_code_redeemed(
1020        self: &Arc<Self>,
1021        inviter_id: UserId,
1022        invitee_id: UserId,
1023    ) -> Result<()> {
1024        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await?
1025            && let Some(code) = &user.invite_code
1026        {
1027            let pool = self.connection_pool.lock();
1028            let invitee_contact = contact_for_user(invitee_id, false, &pool);
1029            for connection_id in pool.user_connection_ids(inviter_id) {
1030                self.peer.send(
1031                    connection_id,
1032                    proto::UpdateContacts {
1033                        contacts: vec![invitee_contact.clone()],
1034                        ..Default::default()
1035                    },
1036                )?;
1037                self.peer.send(
1038                    connection_id,
1039                    proto::UpdateInviteInfo {
1040                        url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1041                        count: user.invite_count as u32,
1042                    },
1043                )?;
1044            }
1045        }
1046        Ok(())
1047    }
1048
1049    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1050        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await?
1051            && let Some(invite_code) = &user.invite_code
1052        {
1053            let pool = self.connection_pool.lock();
1054            for connection_id in pool.user_connection_ids(user_id) {
1055                self.peer.send(
1056                    connection_id,
1057                    proto::UpdateInviteInfo {
1058                        url: format!(
1059                            "{}{}",
1060                            self.app_state.config.invite_link_prefix, invite_code
1061                        ),
1062                        count: user.invite_count as u32,
1063                    },
1064                )?;
1065            }
1066        }
1067        Ok(())
1068    }
1069
1070    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1071        ServerSnapshot {
1072            connection_pool: ConnectionPoolGuard {
1073                guard: self.connection_pool.lock(),
1074                _not_send: PhantomData,
1075            },
1076            peer: &self.peer,
1077        }
1078    }
1079}
1080
1081impl Deref for ConnectionPoolGuard<'_> {
1082    type Target = ConnectionPool;
1083
1084    fn deref(&self) -> &Self::Target {
1085        &self.guard
1086    }
1087}
1088
1089impl DerefMut for ConnectionPoolGuard<'_> {
1090    fn deref_mut(&mut self) -> &mut Self::Target {
1091        &mut self.guard
1092    }
1093}
1094
1095impl Drop for ConnectionPoolGuard<'_> {
1096    fn drop(&mut self) {
1097        #[cfg(test)]
1098        self.check_invariants();
1099    }
1100}
1101
1102fn broadcast<F>(
1103    sender_id: Option<ConnectionId>,
1104    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1105    mut f: F,
1106) where
1107    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1108{
1109    for receiver_id in receiver_ids {
1110        if Some(receiver_id) != sender_id
1111            && let Err(error) = f(receiver_id)
1112        {
1113            tracing::error!("failed to send to {:?} {}", receiver_id, error);
1114        }
1115    }
1116}
1117
1118pub struct ProtocolVersion(u32);
1119
1120impl Header for ProtocolVersion {
1121    fn name() -> &'static HeaderName {
1122        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1123        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1124    }
1125
1126    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1127    where
1128        Self: Sized,
1129        I: Iterator<Item = &'i axum::http::HeaderValue>,
1130    {
1131        let version = values
1132            .next()
1133            .ok_or_else(axum::headers::Error::invalid)?
1134            .to_str()
1135            .map_err(|_| axum::headers::Error::invalid())?
1136            .parse()
1137            .map_err(|_| axum::headers::Error::invalid())?;
1138        Ok(Self(version))
1139    }
1140
1141    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1142        values.extend([self.0.to_string().parse().unwrap()]);
1143    }
1144}
1145
1146pub struct AppVersionHeader(SemanticVersion);
1147impl Header for AppVersionHeader {
1148    fn name() -> &'static HeaderName {
1149        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1150        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1151    }
1152
1153    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1154    where
1155        Self: Sized,
1156        I: Iterator<Item = &'i axum::http::HeaderValue>,
1157    {
1158        let version = values
1159            .next()
1160            .ok_or_else(axum::headers::Error::invalid)?
1161            .to_str()
1162            .map_err(|_| axum::headers::Error::invalid())?
1163            .parse()
1164            .map_err(|_| axum::headers::Error::invalid())?;
1165        Ok(Self(version))
1166    }
1167
1168    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1169        values.extend([self.0.to_string().parse().unwrap()]);
1170    }
1171}
1172
1173#[derive(Debug)]
1174pub struct ReleaseChannelHeader(String);
1175
1176impl Header for ReleaseChannelHeader {
1177    fn name() -> &'static HeaderName {
1178        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1179        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1180    }
1181
1182    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1183    where
1184        Self: Sized,
1185        I: Iterator<Item = &'i axum::http::HeaderValue>,
1186    {
1187        Ok(Self(
1188            values
1189                .next()
1190                .ok_or_else(axum::headers::Error::invalid)?
1191                .to_str()
1192                .map_err(|_| axum::headers::Error::invalid())?
1193                .to_owned(),
1194        ))
1195    }
1196
1197    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1198        values.extend([self.0.parse().unwrap()]);
1199    }
1200}
1201
1202pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1203    Router::new()
1204        .route("/rpc", get(handle_websocket_request))
1205        .layer(
1206            ServiceBuilder::new()
1207                .layer(Extension(server.app_state.clone()))
1208                .layer(middleware::from_fn(auth::validate_header)),
1209        )
1210        .route("/metrics", get(handle_metrics))
1211        .layer(Extension(server))
1212}
1213
1214pub async fn handle_websocket_request(
1215    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1216    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1217    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1218    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1219    Extension(server): Extension<Arc<Server>>,
1220    Extension(principal): Extension<Principal>,
1221    user_agent: Option<TypedHeader<UserAgent>>,
1222    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1223    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1224    ws: WebSocketUpgrade,
1225) -> axum::response::Response {
1226    if protocol_version != rpc::PROTOCOL_VERSION {
1227        return (
1228            StatusCode::UPGRADE_REQUIRED,
1229            "client must be upgraded".to_string(),
1230        )
1231            .into_response();
1232    }
1233
1234    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1235        return (
1236            StatusCode::UPGRADE_REQUIRED,
1237            "no version header found".to_string(),
1238        )
1239            .into_response();
1240    };
1241
1242    let release_channel = release_channel_header.map(|header| header.0.0);
1243
1244    if !version.can_collaborate() {
1245        return (
1246            StatusCode::UPGRADE_REQUIRED,
1247            "client must be upgraded".to_string(),
1248        )
1249            .into_response();
1250    }
1251
1252    let socket_address = socket_address.to_string();
1253
1254    // Acquire connection guard before WebSocket upgrade
1255    let connection_guard = match ConnectionGuard::try_acquire() {
1256        Ok(guard) => guard,
1257        Err(()) => {
1258            return (
1259                StatusCode::SERVICE_UNAVAILABLE,
1260                "Too many concurrent connections",
1261            )
1262                .into_response();
1263        }
1264    };
1265
1266    ws.on_upgrade(move |socket| {
1267        let socket = socket
1268            .map_ok(to_tungstenite_message)
1269            .err_into()
1270            .with(|message| async move { to_axum_message(message) });
1271        let connection = Connection::new(Box::pin(socket));
1272        async move {
1273            server
1274                .handle_connection(
1275                    connection,
1276                    socket_address,
1277                    principal,
1278                    version,
1279                    release_channel,
1280                    user_agent.map(|header| header.to_string()),
1281                    country_code_header.map(|header| header.to_string()),
1282                    system_id_header.map(|header| header.to_string()),
1283                    None,
1284                    Executor::Production,
1285                    Some(connection_guard),
1286                )
1287                .await;
1288        }
1289    })
1290}
1291
1292pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1293    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1294    let connections_metric = CONNECTIONS_METRIC
1295        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1296
1297    let connections = server
1298        .connection_pool
1299        .lock()
1300        .connections()
1301        .filter(|connection| !connection.admin)
1302        .count();
1303    connections_metric.set(connections as _);
1304
1305    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1306    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1307        register_int_gauge!(
1308            "shared_projects",
1309            "number of open projects with one or more guests"
1310        )
1311        .unwrap()
1312    });
1313
1314    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1315    shared_projects_metric.set(shared_projects as _);
1316
1317    let encoder = prometheus::TextEncoder::new();
1318    let metric_families = prometheus::gather();
1319    let encoded_metrics = encoder
1320        .encode_to_string(&metric_families)
1321        .map_err(|err| anyhow!("{err}"))?;
1322    Ok(encoded_metrics)
1323}
1324
/// Cleans up after a connection's message loop has ended.
///
/// These steps run immediately:
/// 1. disconnect the connection from the peer;
/// 2. remove it from the connection pool (an error here is returned);
/// 3. mark it lost in the database (errors are only logged).
///
/// It then waits `RECONNECT_TIMEOUT`. If the teardown watch changes first, nothing else
/// happens and `Ok(())` is returned. Otherwise:
/// - the session leaves its room and channel buffers (errors logged);
/// - if the user has no other connection online, `decline_call` runs and any returned
///   room is passed to `room_updated`;
/// - the user's contacts are updated (an error here is returned).
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Whichever completes first wins: the reconnect grace period, or a teardown signal.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            // Grace period elapsed: release everything this session still holds.

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline the call when no other connection of this user is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1371
1372/// Acknowledges a ping from a client, used to keep the connection alive.
1373async fn ping(
1374    _: proto::Ping,
1375    response: Response<proto::Ping>,
1376    _session: MessageContext,
1377) -> Result<()> {
1378    response.send(proto::Ack {})?;
1379    Ok(())
1380}
1381
1382/// Creates a new room for calling (outside of channels)
1383async fn create_room(
1384    _request: proto::CreateRoom,
1385    response: Response<proto::CreateRoom>,
1386    session: MessageContext,
1387) -> Result<()> {
1388    let livekit_room = nanoid::nanoid!(30);
1389
1390    let live_kit_connection_info = util::maybe!(async {
1391        let live_kit = session.app_state.livekit_client.as_ref();
1392        let live_kit = live_kit?;
1393        let user_id = session.user_id().to_string();
1394
1395        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1396
1397        Some(proto::LiveKitConnectionInfo {
1398            server_url: live_kit.url().into(),
1399            token,
1400            can_publish: true,
1401        })
1402    })
1403    .await;
1404
1405    let room = session
1406        .db()
1407        .await
1408        .create_room(session.user_id(), session.connection_id, &livekit_room)
1409        .await?;
1410
1411    response.send(proto::CreateRoomResponse {
1412        room: Some(room.clone()),
1413        live_kit_connection_info,
1414    })?;
1415
1416    update_user_contacts(session.user_id(), &session).await?;
1417    Ok(())
1418}
1419
1420/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1421async fn join_room(
1422    request: proto::JoinRoom,
1423    response: Response<proto::JoinRoom>,
1424    session: MessageContext,
1425) -> Result<()> {
1426    let room_id = RoomId::from_proto(request.id);
1427
1428    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1429
1430    if let Some(channel_id) = channel_id {
1431        return join_channel_internal(channel_id, Box::new(response), session).await;
1432    }
1433
1434    let joined_room = {
1435        let room = session
1436            .db()
1437            .await
1438            .join_room(room_id, session.user_id(), session.connection_id)
1439            .await?;
1440        room_updated(&room.room, &session.peer);
1441        room.into_inner()
1442    };
1443
1444    for connection_id in session
1445        .connection_pool()
1446        .await
1447        .user_connection_ids(session.user_id())
1448    {
1449        session
1450            .peer
1451            .send(
1452                connection_id,
1453                proto::CallCanceled {
1454                    room_id: room_id.to_proto(),
1455                },
1456            )
1457            .trace_err();
1458    }
1459
1460    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1461    {
1462        live_kit
1463            .room_token(
1464                &joined_room.room.livekit_room,
1465                &session.user_id().to_string(),
1466            )
1467            .trace_err()
1468            .map(|token| proto::LiveKitConnectionInfo {
1469                server_url: live_kit.url().into(),
1470                token,
1471                can_publish: true,
1472            })
1473    } else {
1474        None
1475    };
1476
1477    response.send(proto::JoinRoomResponse {
1478        room: Some(joined_room.room),
1479        channel_id: None,
1480        live_kit_connection_info,
1481    })?;
1482
1483    update_user_contacts(session.user_id(), &session).await?;
1484    Ok(())
1485}
1486
1487/// Rejoin room is used to reconnect to a room after connection errors.
1488async fn rejoin_room(
1489    request: proto::RejoinRoom,
1490    response: Response<proto::RejoinRoom>,
1491    session: MessageContext,
1492) -> Result<()> {
1493    let room;
1494    let channel;
1495    {
1496        let mut rejoined_room = session
1497            .db()
1498            .await
1499            .rejoin_room(request, session.user_id(), session.connection_id)
1500            .await?;
1501
1502        response.send(proto::RejoinRoomResponse {
1503            room: Some(rejoined_room.room.clone()),
1504            reshared_projects: rejoined_room
1505                .reshared_projects
1506                .iter()
1507                .map(|project| proto::ResharedProject {
1508                    id: project.id.to_proto(),
1509                    collaborators: project
1510                        .collaborators
1511                        .iter()
1512                        .map(|collaborator| collaborator.to_proto())
1513                        .collect(),
1514                })
1515                .collect(),
1516            rejoined_projects: rejoined_room
1517                .rejoined_projects
1518                .iter()
1519                .map(|rejoined_project| rejoined_project.to_proto())
1520                .collect(),
1521        })?;
1522        room_updated(&rejoined_room.room, &session.peer);
1523
1524        for project in &rejoined_room.reshared_projects {
1525            for collaborator in &project.collaborators {
1526                session
1527                    .peer
1528                    .send(
1529                        collaborator.connection_id,
1530                        proto::UpdateProjectCollaborator {
1531                            project_id: project.id.to_proto(),
1532                            old_peer_id: Some(project.old_connection_id.into()),
1533                            new_peer_id: Some(session.connection_id.into()),
1534                        },
1535                    )
1536                    .trace_err();
1537            }
1538
1539            broadcast(
1540                Some(session.connection_id),
1541                project
1542                    .collaborators
1543                    .iter()
1544                    .map(|collaborator| collaborator.connection_id),
1545                |connection_id| {
1546                    session.peer.forward_send(
1547                        session.connection_id,
1548                        connection_id,
1549                        proto::UpdateProject {
1550                            project_id: project.id.to_proto(),
1551                            worktrees: project.worktrees.clone(),
1552                        },
1553                    )
1554                },
1555            );
1556        }
1557
1558        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1559
1560        let rejoined_room = rejoined_room.into_inner();
1561
1562        room = rejoined_room.room;
1563        channel = rejoined_room.channel;
1564    }
1565
1566    if let Some(channel) = channel {
1567        channel_updated(
1568            &channel,
1569            &room,
1570            &session.peer,
1571            &*session.connection_pool().await,
1572        );
1573    }
1574
1575    update_user_contacts(session.user_id(), &session).await?;
1576    Ok(())
1577}
1578
1579fn notify_rejoined_projects(
1580    rejoined_projects: &mut Vec<RejoinedProject>,
1581    session: &Session,
1582) -> Result<()> {
1583    for project in rejoined_projects.iter() {
1584        for collaborator in &project.collaborators {
1585            session
1586                .peer
1587                .send(
1588                    collaborator.connection_id,
1589                    proto::UpdateProjectCollaborator {
1590                        project_id: project.id.to_proto(),
1591                        old_peer_id: Some(project.old_connection_id.into()),
1592                        new_peer_id: Some(session.connection_id.into()),
1593                    },
1594                )
1595                .trace_err();
1596        }
1597    }
1598
1599    for project in rejoined_projects {
1600        for worktree in mem::take(&mut project.worktrees) {
1601            // Stream this worktree's entries.
1602            let message = proto::UpdateWorktree {
1603                project_id: project.id.to_proto(),
1604                worktree_id: worktree.id,
1605                abs_path: worktree.abs_path.clone(),
1606                root_name: worktree.root_name,
1607                updated_entries: worktree.updated_entries,
1608                removed_entries: worktree.removed_entries,
1609                scan_id: worktree.scan_id,
1610                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1611                updated_repositories: worktree.updated_repositories,
1612                removed_repositories: worktree.removed_repositories,
1613            };
1614            for update in proto::split_worktree_update(message) {
1615                session.peer.send(session.connection_id, update)?;
1616            }
1617
1618            // Stream this worktree's diagnostics.
1619            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1620            if let Some(summary) = worktree_diagnostics.next() {
1621                let message = proto::UpdateDiagnosticSummary {
1622                    project_id: project.id.to_proto(),
1623                    worktree_id: worktree.id,
1624                    summary: Some(summary),
1625                    more_summaries: worktree_diagnostics.collect(),
1626                };
1627                session.peer.send(session.connection_id, message)?;
1628            }
1629
1630            for settings_file in worktree.settings_files {
1631                session.peer.send(
1632                    session.connection_id,
1633                    proto::UpdateWorktreeSettings {
1634                        project_id: project.id.to_proto(),
1635                        worktree_id: worktree.id,
1636                        path: settings_file.path,
1637                        content: Some(settings_file.content),
1638                        kind: Some(settings_file.kind.to_proto().into()),
1639                    },
1640                )?;
1641            }
1642        }
1643
1644        for repository in mem::take(&mut project.updated_repositories) {
1645            for update in split_repository_update(repository) {
1646                session.peer.send(session.connection_id, update)?;
1647            }
1648        }
1649
1650        for id in mem::take(&mut project.removed_repositories) {
1651            session.peer.send(
1652                session.connection_id,
1653                proto::RemoveRepository {
1654                    project_id: project.id.to_proto(),
1655                    id,
1656                },
1657            )?;
1658        }
1659    }
1660
1661    Ok(())
1662}
1663
1664/// leave room disconnects from the room.
1665async fn leave_room(
1666    _: proto::LeaveRoom,
1667    response: Response<proto::LeaveRoom>,
1668    session: MessageContext,
1669) -> Result<()> {
1670    leave_room_for_session(&session, session.connection_id).await?;
1671    response.send(proto::Ack {})?;
1672    Ok(())
1673}
1674
1675/// Updates the permissions of someone else in the room.
1676async fn set_room_participant_role(
1677    request: proto::SetRoomParticipantRole,
1678    response: Response<proto::SetRoomParticipantRole>,
1679    session: MessageContext,
1680) -> Result<()> {
1681    let user_id = UserId::from_proto(request.user_id);
1682    let role = ChannelRole::from(request.role());
1683
1684    let (livekit_room, can_publish) = {
1685        let room = session
1686            .db()
1687            .await
1688            .set_room_participant_role(
1689                session.user_id(),
1690                RoomId::from_proto(request.room_id),
1691                user_id,
1692                role,
1693            )
1694            .await?;
1695
1696        let livekit_room = room.livekit_room.clone();
1697        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1698        room_updated(&room, &session.peer);
1699        (livekit_room, can_publish)
1700    };
1701
1702    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1703        live_kit
1704            .update_participant(
1705                livekit_room.clone(),
1706                request.user_id.to_string(),
1707                livekit_api::proto::ParticipantPermission {
1708                    can_subscribe: true,
1709                    can_publish,
1710                    can_publish_data: can_publish,
1711                    hidden: false,
1712                    recorder: false,
1713                },
1714            )
1715            .await
1716            .trace_err();
1717    }
1718
1719    response.send(proto::Ack {})?;
1720    Ok(())
1721}
1722
1723/// Call someone else into the current room
1724async fn call(
1725    request: proto::Call,
1726    response: Response<proto::Call>,
1727    session: MessageContext,
1728) -> Result<()> {
1729    let room_id = RoomId::from_proto(request.room_id);
1730    let calling_user_id = session.user_id();
1731    let calling_connection_id = session.connection_id;
1732    let called_user_id = UserId::from_proto(request.called_user_id);
1733    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1734    if !session
1735        .db()
1736        .await
1737        .has_contact(calling_user_id, called_user_id)
1738        .await?
1739    {
1740        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1741    }
1742
1743    let incoming_call = {
1744        let (room, incoming_call) = &mut *session
1745            .db()
1746            .await
1747            .call(
1748                room_id,
1749                calling_user_id,
1750                calling_connection_id,
1751                called_user_id,
1752                initial_project_id,
1753            )
1754            .await?;
1755        room_updated(room, &session.peer);
1756        mem::take(incoming_call)
1757    };
1758    update_user_contacts(called_user_id, &session).await?;
1759
1760    let mut calls = session
1761        .connection_pool()
1762        .await
1763        .user_connection_ids(called_user_id)
1764        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1765        .collect::<FuturesUnordered<_>>();
1766
1767    while let Some(call_response) = calls.next().await {
1768        match call_response.as_ref() {
1769            Ok(_) => {
1770                response.send(proto::Ack {})?;
1771                return Ok(());
1772            }
1773            Err(_) => {
1774                call_response.trace_err();
1775            }
1776        }
1777    }
1778
1779    {
1780        let room = session
1781            .db()
1782            .await
1783            .call_failed(room_id, called_user_id)
1784            .await?;
1785        room_updated(&room, &session.peer);
1786    }
1787    update_user_contacts(called_user_id, &session).await?;
1788
1789    Err(anyhow!("failed to ring user"))?
1790}
1791
1792/// Cancel an outgoing call.
1793async fn cancel_call(
1794    request: proto::CancelCall,
1795    response: Response<proto::CancelCall>,
1796    session: MessageContext,
1797) -> Result<()> {
1798    let called_user_id = UserId::from_proto(request.called_user_id);
1799    let room_id = RoomId::from_proto(request.room_id);
1800    {
1801        let room = session
1802            .db()
1803            .await
1804            .cancel_call(room_id, session.connection_id, called_user_id)
1805            .await?;
1806        room_updated(&room, &session.peer);
1807    }
1808
1809    for connection_id in session
1810        .connection_pool()
1811        .await
1812        .user_connection_ids(called_user_id)
1813    {
1814        session
1815            .peer
1816            .send(
1817                connection_id,
1818                proto::CallCanceled {
1819                    room_id: room_id.to_proto(),
1820                },
1821            )
1822            .trace_err();
1823    }
1824    response.send(proto::Ack {})?;
1825
1826    update_user_contacts(called_user_id, &session).await?;
1827    Ok(())
1828}
1829
1830/// Decline an incoming call.
1831async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1832    let room_id = RoomId::from_proto(message.room_id);
1833    {
1834        let room = session
1835            .db()
1836            .await
1837            .decline_call(Some(room_id), session.user_id())
1838            .await?
1839            .context("declining call")?;
1840        room_updated(&room, &session.peer);
1841    }
1842
1843    for connection_id in session
1844        .connection_pool()
1845        .await
1846        .user_connection_ids(session.user_id())
1847    {
1848        session
1849            .peer
1850            .send(
1851                connection_id,
1852                proto::CallCanceled {
1853                    room_id: room_id.to_proto(),
1854                },
1855            )
1856            .trace_err();
1857    }
1858    update_user_contacts(session.user_id(), &session).await?;
1859    Ok(())
1860}
1861
1862/// Updates other participants in the room with your current location.
1863async fn update_participant_location(
1864    request: proto::UpdateParticipantLocation,
1865    response: Response<proto::UpdateParticipantLocation>,
1866    session: MessageContext,
1867) -> Result<()> {
1868    let room_id = RoomId::from_proto(request.room_id);
1869    let location = request.location.context("invalid location")?;
1870
1871    let db = session.db().await;
1872    let room = db
1873        .update_room_participant_location(room_id, session.connection_id, location)
1874        .await?;
1875
1876    room_updated(&room, &session.peer);
1877    response.send(proto::Ack {})?;
1878    Ok(())
1879}
1880
1881/// Share a project into the room.
1882async fn share_project(
1883    request: proto::ShareProject,
1884    response: Response<proto::ShareProject>,
1885    session: MessageContext,
1886) -> Result<()> {
1887    let (project_id, room) = &*session
1888        .db()
1889        .await
1890        .share_project(
1891            RoomId::from_proto(request.room_id),
1892            session.connection_id,
1893            &request.worktrees,
1894            request.is_ssh_project,
1895        )
1896        .await?;
1897    response.send(proto::ShareProjectResponse {
1898        project_id: project_id.to_proto(),
1899    })?;
1900    room_updated(room, &session.peer);
1901
1902    Ok(())
1903}
1904
1905/// Unshare a project from the room.
1906async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1907    let project_id = ProjectId::from_proto(message.project_id);
1908    unshare_project_internal(project_id, session.connection_id, &session).await
1909}
1910
1911async fn unshare_project_internal(
1912    project_id: ProjectId,
1913    connection_id: ConnectionId,
1914    session: &Session,
1915) -> Result<()> {
1916    let delete = {
1917        let room_guard = session
1918            .db()
1919            .await
1920            .unshare_project(project_id, connection_id)
1921            .await?;
1922
1923        let (delete, room, guest_connection_ids) = &*room_guard;
1924
1925        let message = proto::UnshareProject {
1926            project_id: project_id.to_proto(),
1927        };
1928
1929        broadcast(
1930            Some(connection_id),
1931            guest_connection_ids.iter().copied(),
1932            |conn_id| session.peer.send(conn_id, message.clone()),
1933        );
1934        if let Some(room) = room {
1935            room_updated(room, &session.peer);
1936        }
1937
1938        *delete
1939    };
1940
1941    if delete {
1942        let db = session.db().await;
1943        db.delete_project(project_id).await?;
1944    }
1945
1946    Ok(())
1947}
1948
1949/// Join someone elses shared project.
1950async fn join_project(
1951    request: proto::JoinProject,
1952    response: Response<proto::JoinProject>,
1953    session: MessageContext,
1954) -> Result<()> {
1955    let project_id = ProjectId::from_proto(request.project_id);
1956
1957    tracing::info!(%project_id, "join project");
1958
1959    let db = session.db().await;
1960    let (project, replica_id) = &mut *db
1961        .join_project(
1962            project_id,
1963            session.connection_id,
1964            session.user_id(),
1965            request.committer_name.clone(),
1966            request.committer_email.clone(),
1967        )
1968        .await?;
1969    drop(db);
1970    tracing::info!(%project_id, "join remote project");
1971    let collaborators = project
1972        .collaborators
1973        .iter()
1974        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1975        .map(|collaborator| collaborator.to_proto())
1976        .collect::<Vec<_>>();
1977    let project_id = project.id;
1978    let guest_user_id = session.user_id();
1979
1980    let worktrees = project
1981        .worktrees
1982        .iter()
1983        .map(|(id, worktree)| proto::WorktreeMetadata {
1984            id: *id,
1985            root_name: worktree.root_name.clone(),
1986            visible: worktree.visible,
1987            abs_path: worktree.abs_path.clone(),
1988        })
1989        .collect::<Vec<_>>();
1990
1991    let add_project_collaborator = proto::AddProjectCollaborator {
1992        project_id: project_id.to_proto(),
1993        collaborator: Some(proto::Collaborator {
1994            peer_id: Some(session.connection_id.into()),
1995            replica_id: replica_id.0 as u32,
1996            user_id: guest_user_id.to_proto(),
1997            is_host: false,
1998            committer_name: request.committer_name.clone(),
1999            committer_email: request.committer_email.clone(),
2000        }),
2001    };
2002
2003    for collaborator in &collaborators {
2004        session
2005            .peer
2006            .send(
2007                collaborator.peer_id.unwrap().into(),
2008                add_project_collaborator.clone(),
2009            )
2010            .trace_err();
2011    }
2012
2013    // First, we send the metadata associated with each worktree.
2014    let (language_servers, language_server_capabilities) = project
2015        .language_servers
2016        .clone()
2017        .into_iter()
2018        .map(|server| (server.server, server.capabilities))
2019        .unzip();
2020    response.send(proto::JoinProjectResponse {
2021        project_id: project.id.0 as u64,
2022        worktrees,
2023        replica_id: replica_id.0 as u32,
2024        collaborators,
2025        language_servers,
2026        language_server_capabilities,
2027        role: project.role.into(),
2028    })?;
2029
2030    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2031        // Stream this worktree's entries.
2032        let message = proto::UpdateWorktree {
2033            project_id: project_id.to_proto(),
2034            worktree_id,
2035            abs_path: worktree.abs_path.clone(),
2036            root_name: worktree.root_name,
2037            updated_entries: worktree.entries,
2038            removed_entries: Default::default(),
2039            scan_id: worktree.scan_id,
2040            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2041            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2042            removed_repositories: Default::default(),
2043        };
2044        for update in proto::split_worktree_update(message) {
2045            session.peer.send(session.connection_id, update.clone())?;
2046        }
2047
2048        // Stream this worktree's diagnostics.
2049        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
2050        if let Some(summary) = worktree_diagnostics.next() {
2051            let message = proto::UpdateDiagnosticSummary {
2052                project_id: project.id.to_proto(),
2053                worktree_id: worktree.id,
2054                summary: Some(summary),
2055                more_summaries: worktree_diagnostics.collect(),
2056            };
2057            session.peer.send(session.connection_id, message)?;
2058        }
2059
2060        for settings_file in worktree.settings_files {
2061            session.peer.send(
2062                session.connection_id,
2063                proto::UpdateWorktreeSettings {
2064                    project_id: project_id.to_proto(),
2065                    worktree_id: worktree.id,
2066                    path: settings_file.path,
2067                    content: Some(settings_file.content),
2068                    kind: Some(settings_file.kind.to_proto() as i32),
2069                },
2070            )?;
2071        }
2072    }
2073
2074    for repository in mem::take(&mut project.repositories) {
2075        for update in split_repository_update(repository) {
2076            session.peer.send(session.connection_id, update)?;
2077        }
2078    }
2079
2080    for language_server in &project.language_servers {
2081        session.peer.send(
2082            session.connection_id,
2083            proto::UpdateLanguageServer {
2084                project_id: project_id.to_proto(),
2085                server_name: Some(language_server.server.name.clone()),
2086                language_server_id: language_server.server.id,
2087                variant: Some(
2088                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2089                        proto::LspDiskBasedDiagnosticsUpdated {},
2090                    ),
2091                ),
2092            },
2093        )?;
2094    }
2095
2096    Ok(())
2097}
2098
2099/// Leave someone elses shared project.
2100async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2101    let sender_id = session.connection_id;
2102    let project_id = ProjectId::from_proto(request.project_id);
2103    let db = session.db().await;
2104
2105    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2106    tracing::info!(
2107        %project_id,
2108        "leave project"
2109    );
2110
2111    project_left(project, &session);
2112    if let Some(room) = room {
2113        room_updated(room, &session.peer);
2114    }
2115
2116    Ok(())
2117}
2118
2119/// Updates other participants with changes to the project
2120async fn update_project(
2121    request: proto::UpdateProject,
2122    response: Response<proto::UpdateProject>,
2123    session: MessageContext,
2124) -> Result<()> {
2125    let project_id = ProjectId::from_proto(request.project_id);
2126    let (room, guest_connection_ids) = &*session
2127        .db()
2128        .await
2129        .update_project(project_id, session.connection_id, &request.worktrees)
2130        .await?;
2131    broadcast(
2132        Some(session.connection_id),
2133        guest_connection_ids.iter().copied(),
2134        |connection_id| {
2135            session
2136                .peer
2137                .forward_send(session.connection_id, connection_id, request.clone())
2138        },
2139    );
2140    if let Some(room) = room {
2141        room_updated(room, &session.peer);
2142    }
2143    response.send(proto::Ack {})?;
2144
2145    Ok(())
2146}
2147
2148/// Updates other participants with changes to the worktree
2149async fn update_worktree(
2150    request: proto::UpdateWorktree,
2151    response: Response<proto::UpdateWorktree>,
2152    session: MessageContext,
2153) -> Result<()> {
2154    let guest_connection_ids = session
2155        .db()
2156        .await
2157        .update_worktree(&request, session.connection_id)
2158        .await?;
2159
2160    broadcast(
2161        Some(session.connection_id),
2162        guest_connection_ids.iter().copied(),
2163        |connection_id| {
2164            session
2165                .peer
2166                .forward_send(session.connection_id, connection_id, request.clone())
2167        },
2168    );
2169    response.send(proto::Ack {})?;
2170    Ok(())
2171}
2172
2173async fn update_repository(
2174    request: proto::UpdateRepository,
2175    response: Response<proto::UpdateRepository>,
2176    session: MessageContext,
2177) -> Result<()> {
2178    let guest_connection_ids = session
2179        .db()
2180        .await
2181        .update_repository(&request, session.connection_id)
2182        .await?;
2183
2184    broadcast(
2185        Some(session.connection_id),
2186        guest_connection_ids.iter().copied(),
2187        |connection_id| {
2188            session
2189                .peer
2190                .forward_send(session.connection_id, connection_id, request.clone())
2191        },
2192    );
2193    response.send(proto::Ack {})?;
2194    Ok(())
2195}
2196
2197async fn remove_repository(
2198    request: proto::RemoveRepository,
2199    response: Response<proto::RemoveRepository>,
2200    session: MessageContext,
2201) -> Result<()> {
2202    let guest_connection_ids = session
2203        .db()
2204        .await
2205        .remove_repository(&request, session.connection_id)
2206        .await?;
2207
2208    broadcast(
2209        Some(session.connection_id),
2210        guest_connection_ids.iter().copied(),
2211        |connection_id| {
2212            session
2213                .peer
2214                .forward_send(session.connection_id, connection_id, request.clone())
2215        },
2216    );
2217    response.send(proto::Ack {})?;
2218    Ok(())
2219}
2220
2221/// Updates other participants with changes to the diagnostics
2222async fn update_diagnostic_summary(
2223    message: proto::UpdateDiagnosticSummary,
2224    session: MessageContext,
2225) -> Result<()> {
2226    let guest_connection_ids = session
2227        .db()
2228        .await
2229        .update_diagnostic_summary(&message, session.connection_id)
2230        .await?;
2231
2232    broadcast(
2233        Some(session.connection_id),
2234        guest_connection_ids.iter().copied(),
2235        |connection_id| {
2236            session
2237                .peer
2238                .forward_send(session.connection_id, connection_id, message.clone())
2239        },
2240    );
2241
2242    Ok(())
2243}
2244
2245/// Updates other participants with changes to the worktree settings
2246async fn update_worktree_settings(
2247    message: proto::UpdateWorktreeSettings,
2248    session: MessageContext,
2249) -> Result<()> {
2250    let guest_connection_ids = session
2251        .db()
2252        .await
2253        .update_worktree_settings(&message, session.connection_id)
2254        .await?;
2255
2256    broadcast(
2257        Some(session.connection_id),
2258        guest_connection_ids.iter().copied(),
2259        |connection_id| {
2260            session
2261                .peer
2262                .forward_send(session.connection_id, connection_id, message.clone())
2263        },
2264    );
2265
2266    Ok(())
2267}
2268
2269/// Notify other participants that a language server has started.
2270async fn start_language_server(
2271    request: proto::StartLanguageServer,
2272    session: MessageContext,
2273) -> Result<()> {
2274    let guest_connection_ids = session
2275        .db()
2276        .await
2277        .start_language_server(&request, session.connection_id)
2278        .await?;
2279
2280    broadcast(
2281        Some(session.connection_id),
2282        guest_connection_ids.iter().copied(),
2283        |connection_id| {
2284            session
2285                .peer
2286                .forward_send(session.connection_id, connection_id, request.clone())
2287        },
2288    );
2289    Ok(())
2290}
2291
2292/// Notify other participants that a language server has changed.
2293async fn update_language_server(
2294    request: proto::UpdateLanguageServer,
2295    session: MessageContext,
2296) -> Result<()> {
2297    let project_id = ProjectId::from_proto(request.project_id);
2298    let db = session.db().await;
2299
2300    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2301        && let Some(capabilities) = update.capabilities.clone()
2302    {
2303        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2304            .await?;
2305    }
2306
2307    let project_connection_ids = db
2308        .project_connection_ids(project_id, session.connection_id, true)
2309        .await?;
2310    broadcast(
2311        Some(session.connection_id),
2312        project_connection_ids.iter().copied(),
2313        |connection_id| {
2314            session
2315                .peer
2316                .forward_send(session.connection_id, connection_id, request.clone())
2317        },
2318    );
2319    Ok(())
2320}
2321
2322/// forward a project request to the host. These requests should be read only
2323/// as guests are allowed to send them.
2324async fn forward_read_only_project_request<T>(
2325    request: T,
2326    response: Response<T>,
2327    session: MessageContext,
2328) -> Result<()>
2329where
2330    T: EntityMessage + RequestMessage,
2331{
2332    let project_id = ProjectId::from_proto(request.remote_entity_id());
2333    let host_connection_id = session
2334        .db()
2335        .await
2336        .host_for_read_only_project_request(project_id, session.connection_id)
2337        .await?;
2338    let payload = session.forward_request(host_connection_id, request).await?;
2339    response.send(payload)?;
2340    Ok(())
2341}
2342
2343/// forward a project request to the host. These requests are disallowed
2344/// for guests.
2345async fn forward_mutating_project_request<T>(
2346    request: T,
2347    response: Response<T>,
2348    session: MessageContext,
2349) -> Result<()>
2350where
2351    T: EntityMessage + RequestMessage,
2352{
2353    let project_id = ProjectId::from_proto(request.remote_entity_id());
2354
2355    let host_connection_id = session
2356        .db()
2357        .await
2358        .host_for_mutating_project_request(project_id, session.connection_id)
2359        .await?;
2360    let payload = session.forward_request(host_connection_id, request).await?;
2361    response.send(payload)?;
2362    Ok(())
2363}
2364
2365// todo(lsp) remove after Zed Stable hits v0.204.x
2366async fn multi_lsp_query(
2367    request: MultiLspQuery,
2368    response: Response<MultiLspQuery>,
2369    session: MessageContext,
2370) -> Result<()> {
2371    tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2372    tracing::info!("multi_lsp_query message received");
2373    forward_mutating_project_request(request, response, session).await
2374}
2375
2376async fn lsp_query(
2377    request: proto::LspQuery,
2378    response: Response<proto::LspQuery>,
2379    session: MessageContext,
2380) -> Result<()> {
2381    let (name, should_write) = request.query_name_and_write_permissions();
2382    tracing::Span::current().record("lsp_query_request", name);
2383    tracing::info!("lsp_query message received");
2384    if should_write {
2385        forward_mutating_project_request(request, response, session).await
2386    } else {
2387        forward_read_only_project_request(request, response, session).await
2388    }
2389}
2390
2391/// Notify other participants that a new buffer has been created
2392async fn create_buffer_for_peer(
2393    request: proto::CreateBufferForPeer,
2394    session: MessageContext,
2395) -> Result<()> {
2396    session
2397        .db()
2398        .await
2399        .check_user_is_project_host(
2400            ProjectId::from_proto(request.project_id),
2401            session.connection_id,
2402        )
2403        .await?;
2404    let peer_id = request.peer_id.context("invalid peer id")?;
2405    session
2406        .peer
2407        .forward_send(session.connection_id, peer_id.into(), request)?;
2408    Ok(())
2409}
2410
2411/// Notify other participants that a buffer has been updated. This is
2412/// allowed for guests as long as the update is limited to selections.
2413async fn update_buffer(
2414    request: proto::UpdateBuffer,
2415    response: Response<proto::UpdateBuffer>,
2416    session: MessageContext,
2417) -> Result<()> {
2418    let project_id = ProjectId::from_proto(request.project_id);
2419    let mut capability = Capability::ReadOnly;
2420
2421    for op in request.operations.iter() {
2422        match op.variant {
2423            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2424            Some(_) => capability = Capability::ReadWrite,
2425        }
2426    }
2427
2428    let host = {
2429        let guard = session
2430            .db()
2431            .await
2432            .connections_for_buffer_update(project_id, session.connection_id, capability)
2433            .await?;
2434
2435        let (host, guests) = &*guard;
2436
2437        broadcast(
2438            Some(session.connection_id),
2439            guests.clone(),
2440            |connection_id| {
2441                session
2442                    .peer
2443                    .forward_send(session.connection_id, connection_id, request.clone())
2444            },
2445        );
2446
2447        *host
2448    };
2449
2450    if host != session.connection_id {
2451        session.forward_request(host, request.clone()).await?;
2452    }
2453
2454    response.send(proto::Ack {})?;
2455    Ok(())
2456}
2457
2458async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2459    let project_id = ProjectId::from_proto(message.project_id);
2460
2461    let operation = message.operation.as_ref().context("invalid operation")?;
2462    let capability = match operation.variant.as_ref() {
2463        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2464            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2465                match buffer_op.variant {
2466                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2467                        Capability::ReadOnly
2468                    }
2469                    _ => Capability::ReadWrite,
2470                }
2471            } else {
2472                Capability::ReadWrite
2473            }
2474        }
2475        Some(_) => Capability::ReadWrite,
2476        None => Capability::ReadOnly,
2477    };
2478
2479    let guard = session
2480        .db()
2481        .await
2482        .connections_for_buffer_update(project_id, session.connection_id, capability)
2483        .await?;
2484
2485    let (host, guests) = &*guard;
2486
2487    broadcast(
2488        Some(session.connection_id),
2489        guests.iter().chain([host]).copied(),
2490        |connection_id| {
2491            session
2492                .peer
2493                .forward_send(session.connection_id, connection_id, message.clone())
2494        },
2495    );
2496
2497    Ok(())
2498}
2499
2500/// Notify other participants that a project has been updated.
2501async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2502    request: T,
2503    session: MessageContext,
2504) -> Result<()> {
2505    let project_id = ProjectId::from_proto(request.remote_entity_id());
2506    let project_connection_ids = session
2507        .db()
2508        .await
2509        .project_connection_ids(project_id, session.connection_id, false)
2510        .await?;
2511
2512    broadcast(
2513        Some(session.connection_id),
2514        project_connection_ids.iter().copied(),
2515        |connection_id| {
2516            session
2517                .peer
2518                .forward_send(session.connection_id, connection_id, request.clone())
2519        },
2520    );
2521    Ok(())
2522}
2523
2524/// Start following another user in a call.
2525async fn follow(
2526    request: proto::Follow,
2527    response: Response<proto::Follow>,
2528    session: MessageContext,
2529) -> Result<()> {
2530    let room_id = RoomId::from_proto(request.room_id);
2531    let project_id = request.project_id.map(ProjectId::from_proto);
2532    let leader_id = request.leader_id.context("invalid leader id")?.into();
2533    let follower_id = session.connection_id;
2534
2535    session
2536        .db()
2537        .await
2538        .check_room_participants(room_id, leader_id, session.connection_id)
2539        .await?;
2540
2541    let response_payload = session.forward_request(leader_id, request).await?;
2542    response.send(response_payload)?;
2543
2544    if let Some(project_id) = project_id {
2545        let room = session
2546            .db()
2547            .await
2548            .follow(room_id, project_id, leader_id, follower_id)
2549            .await?;
2550        room_updated(&room, &session.peer);
2551    }
2552
2553    Ok(())
2554}
2555
2556/// Stop following another user in a call.
2557async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2558    let room_id = RoomId::from_proto(request.room_id);
2559    let project_id = request.project_id.map(ProjectId::from_proto);
2560    let leader_id = request.leader_id.context("invalid leader id")?.into();
2561    let follower_id = session.connection_id;
2562
2563    session
2564        .db()
2565        .await
2566        .check_room_participants(room_id, leader_id, session.connection_id)
2567        .await?;
2568
2569    session
2570        .peer
2571        .forward_send(session.connection_id, leader_id, request)?;
2572
2573    if let Some(project_id) = project_id {
2574        let room = session
2575            .db()
2576            .await
2577            .unfollow(room_id, project_id, leader_id, follower_id)
2578            .await?;
2579        room_updated(&room, &session.peer);
2580    }
2581
2582    Ok(())
2583}
2584
2585/// Notify everyone following you of your current location.
2586async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2587    let room_id = RoomId::from_proto(request.room_id);
2588    let database = session.db.lock().await;
2589
2590    let connection_ids = if let Some(project_id) = request.project_id {
2591        let project_id = ProjectId::from_proto(project_id);
2592        database
2593            .project_connection_ids(project_id, session.connection_id, true)
2594            .await?
2595    } else {
2596        database
2597            .room_connection_ids(room_id, session.connection_id)
2598            .await?
2599    };
2600
2601    // For now, don't send view update messages back to that view's current leader.
2602    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2603        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2604        _ => None,
2605    });
2606
2607    for connection_id in connection_ids.iter().cloned() {
2608        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2609            session
2610                .peer
2611                .forward_send(session.connection_id, connection_id, request.clone())?;
2612        }
2613    }
2614    Ok(())
2615}
2616
2617/// Get public data about users.
2618async fn get_users(
2619    request: proto::GetUsers,
2620    response: Response<proto::GetUsers>,
2621    session: MessageContext,
2622) -> Result<()> {
2623    let user_ids = request
2624        .user_ids
2625        .into_iter()
2626        .map(UserId::from_proto)
2627        .collect();
2628    let users = session
2629        .db()
2630        .await
2631        .get_users_by_ids(user_ids)
2632        .await?
2633        .into_iter()
2634        .map(|user| proto::User {
2635            id: user.id.to_proto(),
2636            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2637            github_login: user.github_login,
2638            name: user.name,
2639        })
2640        .collect();
2641    response.send(proto::UsersResponse { users })?;
2642    Ok(())
2643}
2644
2645/// Search for users (to invite) buy Github login
2646async fn fuzzy_search_users(
2647    request: proto::FuzzySearchUsers,
2648    response: Response<proto::FuzzySearchUsers>,
2649    session: MessageContext,
2650) -> Result<()> {
2651    let query = request.query;
2652    let users = match query.len() {
2653        0 => vec![],
2654        1 | 2 => session
2655            .db()
2656            .await
2657            .get_user_by_github_login(&query)
2658            .await?
2659            .into_iter()
2660            .collect(),
2661        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2662    };
2663    let users = users
2664        .into_iter()
2665        .filter(|user| user.id != session.user_id())
2666        .map(|user| proto::User {
2667            id: user.id.to_proto(),
2668            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2669            github_login: user.github_login,
2670            name: user.name,
2671        })
2672        .collect();
2673    response.send(proto::UsersResponse { users })?;
2674    Ok(())
2675}
2676
2677/// Send a contact request to another user.
2678async fn request_contact(
2679    request: proto::RequestContact,
2680    response: Response<proto::RequestContact>,
2681    session: MessageContext,
2682) -> Result<()> {
2683    let requester_id = session.user_id();
2684    let responder_id = UserId::from_proto(request.responder_id);
2685    if requester_id == responder_id {
2686        return Err(anyhow!("cannot add yourself as a contact"))?;
2687    }
2688
2689    let notifications = session
2690        .db()
2691        .await
2692        .send_contact_request(requester_id, responder_id)
2693        .await?;
2694
2695    // Update outgoing contact requests of requester
2696    let mut update = proto::UpdateContacts::default();
2697    update.outgoing_requests.push(responder_id.to_proto());
2698    for connection_id in session
2699        .connection_pool()
2700        .await
2701        .user_connection_ids(requester_id)
2702    {
2703        session.peer.send(connection_id, update.clone())?;
2704    }
2705
2706    // Update incoming contact requests of responder
2707    let mut update = proto::UpdateContacts::default();
2708    update
2709        .incoming_requests
2710        .push(proto::IncomingContactRequest {
2711            requester_id: requester_id.to_proto(),
2712        });
2713    let connection_pool = session.connection_pool().await;
2714    for connection_id in connection_pool.user_connection_ids(responder_id) {
2715        session.peer.send(connection_id, update.clone())?;
2716    }
2717
2718    send_notifications(&connection_pool, &session.peer, notifications);
2719
2720    response.send(proto::Ack {})?;
2721    Ok(())
2722}
2723
/// Accept or decline a contact request, or dismiss its notification.
///
/// * `Dismiss`: only `dismiss_contact_notification` is called. No contact updates are
///   pushed to any connection.
/// * Any other response calls `respond_to_contact_request` with
///   `accept = (response == Accept)`. Both users' connections then receive an
///   `UpdateContacts`:
///   - the responder's incoming request and the requester's outgoing request are removed;
///   - when accepted, each side also gains the other as a contact, including that user's
///     current busy status.
///
///   Any notifications returned by the database are delivered afterwards.
///
/// In every case the request is acknowledged at the end.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    // The responder is the authenticated user of this session.
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    // This guard stays alive until the handler returns, including while the connection
    // pool is locked below.
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        // Any value other than `Accept` (and `Dismiss`, handled above) is passed through
        // as `accept = false`.
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy status is only consumed when `accept` is true (see the `contacts` pushes
        // below), but it is queried unconditionally.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2781
/// Remove a contact.
///
/// Removes the relationship between the session's user (`requester_id`) and
/// `request.user_id` (`responder_id`):
/// - if the database reports the contact had been accepted, both users are told to drop
///   each other from their contacts;
/// - otherwise the pending request is withdrawn on both sides.
///
/// If the database also returns a deleted notification id, the responder's connections
/// are told to delete that notification.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update the requester: drop the contact if it was accepted, otherwise withdraw
    // their outgoing request.
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update the responder: drop the contact if it was accepted, otherwise remove the
    // incoming request. Also delete the related notification if the database removed one.
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2832
2833fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2834    version.0.minor() < 139
2835}
2836
2837async fn subscribe_to_channels(
2838    _: proto::SubscribeToChannels,
2839    session: MessageContext,
2840) -> Result<()> {
2841    subscribe_user_to_channels(session.user_id(), &session).await?;
2842    Ok(())
2843}
2844
2845async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2846    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2847    let mut pool = session.connection_pool().await;
2848    for membership in &channels_for_user.channel_memberships {
2849        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2850    }
2851    session.peer.send(
2852        session.connection_id,
2853        build_update_user_channels(&channels_for_user),
2854    )?;
2855    session.peer.send(
2856        session.connection_id,
2857        build_channels_update(channels_for_user),
2858    )?;
2859    Ok(())
2860}
2861
2862/// Creates a new channel.
2863async fn create_channel(
2864    request: proto::CreateChannel,
2865    response: Response<proto::CreateChannel>,
2866    session: MessageContext,
2867) -> Result<()> {
2868    let db = session.db().await;
2869
2870    let parent_id = request.parent_id.map(ChannelId::from_proto);
2871    let (channel, membership) = db
2872        .create_channel(&request.name, parent_id, session.user_id())
2873        .await?;
2874
2875    let root_id = channel.root_id();
2876    let channel = Channel::from_model(channel);
2877
2878    response.send(proto::CreateChannelResponse {
2879        channel: Some(channel.to_proto()),
2880        parent_id: request.parent_id,
2881    })?;
2882
2883    let mut connection_pool = session.connection_pool().await;
2884    if let Some(membership) = membership {
2885        connection_pool.subscribe_to_channel(
2886            membership.user_id,
2887            membership.channel_id,
2888            membership.role,
2889        );
2890        let update = proto::UpdateUserChannels {
2891            channel_memberships: vec![proto::ChannelMembership {
2892                channel_id: membership.channel_id.to_proto(),
2893                role: membership.role.into(),
2894            }],
2895            ..Default::default()
2896        };
2897        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2898            session.peer.send(connection_id, update.clone())?;
2899        }
2900    }
2901
2902    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2903        if !role.can_see_channel(channel.visibility) {
2904            continue;
2905        }
2906
2907        let update = proto::UpdateChannels {
2908            channels: vec![channel.to_proto()],
2909            ..Default::default()
2910        };
2911        session.peer.send(connection_id, update.clone())?;
2912    }
2913
2914    Ok(())
2915}
2916
2917/// Delete a channel
2918async fn delete_channel(
2919    request: proto::DeleteChannel,
2920    response: Response<proto::DeleteChannel>,
2921    session: MessageContext,
2922) -> Result<()> {
2923    let db = session.db().await;
2924
2925    let channel_id = request.channel_id;
2926    let (root_channel, removed_channels) = db
2927        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2928        .await?;
2929    response.send(proto::Ack {})?;
2930
2931    // Notify members of removed channels
2932    let mut update = proto::UpdateChannels::default();
2933    update
2934        .delete_channels
2935        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2936
2937    let connection_pool = session.connection_pool().await;
2938    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2939        session.peer.send(connection_id, update.clone())?;
2940    }
2941
2942    Ok(())
2943}
2944
2945/// Invite someone to join a channel.
2946async fn invite_channel_member(
2947    request: proto::InviteChannelMember,
2948    response: Response<proto::InviteChannelMember>,
2949    session: MessageContext,
2950) -> Result<()> {
2951    let db = session.db().await;
2952    let channel_id = ChannelId::from_proto(request.channel_id);
2953    let invitee_id = UserId::from_proto(request.user_id);
2954    let InviteMemberResult {
2955        channel,
2956        notifications,
2957    } = db
2958        .invite_channel_member(
2959            channel_id,
2960            invitee_id,
2961            session.user_id(),
2962            request.role().into(),
2963        )
2964        .await?;
2965
2966    let update = proto::UpdateChannels {
2967        channel_invitations: vec![channel.to_proto()],
2968        ..Default::default()
2969    };
2970
2971    let connection_pool = session.connection_pool().await;
2972    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2973        session.peer.send(connection_id, update.clone())?;
2974    }
2975
2976    send_notifications(&connection_pool, &session.peer, notifications);
2977
2978    response.send(proto::Ack {})?;
2979    Ok(())
2980}
2981
2982/// remove someone from a channel
2983async fn remove_channel_member(
2984    request: proto::RemoveChannelMember,
2985    response: Response<proto::RemoveChannelMember>,
2986    session: MessageContext,
2987) -> Result<()> {
2988    let db = session.db().await;
2989    let channel_id = ChannelId::from_proto(request.channel_id);
2990    let member_id = UserId::from_proto(request.user_id);
2991
2992    let RemoveChannelMemberResult {
2993        membership_update,
2994        notification_id,
2995    } = db
2996        .remove_channel_member(channel_id, member_id, session.user_id())
2997        .await?;
2998
2999    let mut connection_pool = session.connection_pool().await;
3000    notify_membership_updated(
3001        &mut connection_pool,
3002        membership_update,
3003        member_id,
3004        &session.peer,
3005    );
3006    for connection_id in connection_pool.user_connection_ids(member_id) {
3007        if let Some(notification_id) = notification_id {
3008            session
3009                .peer
3010                .send(
3011                    connection_id,
3012                    proto::DeleteNotification {
3013                        notification_id: notification_id.to_proto(),
3014                    },
3015                )
3016                .trace_err();
3017        }
3018    }
3019
3020    response.send(proto::Ack {})?;
3021    Ok(())
3022}
3023
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels
/// (though members-only channels can appear at any point in the hierarchy).
/// NOTE(review): that invariant is not enforced in this handler; presumably the database's
/// `set_channel_visibility` enforces it.
///
/// After the database update, every user subscribed to the channel's root is
/// re-evaluated:
/// - users whose role can see the channel under its new visibility are subscribed to it
///   and sent the updated channel;
/// - all other users are unsubscribed and told to delete it.
///
/// The request is acknowledged last.
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Snapshot the (user, role) pairs first: the loop body subscribes and unsubscribes
    // through `connection_pool`, which it could not do while iterating a borrow of it.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3070
3071/// Alter the role for a user in the channel.
3072async fn set_channel_member_role(
3073    request: proto::SetChannelMemberRole,
3074    response: Response<proto::SetChannelMemberRole>,
3075    session: MessageContext,
3076) -> Result<()> {
3077    let db = session.db().await;
3078    let channel_id = ChannelId::from_proto(request.channel_id);
3079    let member_id = UserId::from_proto(request.user_id);
3080    let result = db
3081        .set_channel_member_role(
3082            channel_id,
3083            session.user_id(),
3084            member_id,
3085            request.role().into(),
3086        )
3087        .await?;
3088
3089    match result {
3090        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3091            let mut connection_pool = session.connection_pool().await;
3092            notify_membership_updated(
3093                &mut connection_pool,
3094                membership_update,
3095                member_id,
3096                &session.peer,
3097            )
3098        }
3099        db::SetMemberRoleResult::InviteUpdated(channel) => {
3100            let update = proto::UpdateChannels {
3101                channel_invitations: vec![channel.to_proto()],
3102                ..Default::default()
3103            };
3104
3105            for connection_id in session
3106                .connection_pool()
3107                .await
3108                .user_connection_ids(member_id)
3109            {
3110                session.peer.send(connection_id, update.clone())?;
3111            }
3112        }
3113    }
3114
3115    response.send(proto::Ack {})?;
3116    Ok(())
3117}
3118
3119/// Change the name of a channel
3120async fn rename_channel(
3121    request: proto::RenameChannel,
3122    response: Response<proto::RenameChannel>,
3123    session: MessageContext,
3124) -> Result<()> {
3125    let db = session.db().await;
3126    let channel_id = ChannelId::from_proto(request.channel_id);
3127    let channel_model = db
3128        .rename_channel(channel_id, session.user_id(), &request.name)
3129        .await?;
3130    let root_id = channel_model.root_id();
3131    let channel = Channel::from_model(channel_model);
3132
3133    response.send(proto::RenameChannelResponse {
3134        channel: Some(channel.to_proto()),
3135    })?;
3136
3137    let connection_pool = session.connection_pool().await;
3138    let update = proto::UpdateChannels {
3139        channels: vec![channel.to_proto()],
3140        ..Default::default()
3141    };
3142    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3143        if role.can_see_channel(channel.visibility) {
3144            session.peer.send(connection_id, update.clone())?;
3145        }
3146    }
3147
3148    Ok(())
3149}
3150
3151/// Move a channel to a new parent.
3152async fn move_channel(
3153    request: proto::MoveChannel,
3154    response: Response<proto::MoveChannel>,
3155    session: MessageContext,
3156) -> Result<()> {
3157    let channel_id = ChannelId::from_proto(request.channel_id);
3158    let to = ChannelId::from_proto(request.to);
3159
3160    let (root_id, channels) = session
3161        .db()
3162        .await
3163        .move_channel(channel_id, to, session.user_id())
3164        .await?;
3165
3166    let connection_pool = session.connection_pool().await;
3167    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3168        let channels = channels
3169            .iter()
3170            .filter_map(|channel| {
3171                if role.can_see_channel(channel.visibility) {
3172                    Some(channel.to_proto())
3173                } else {
3174                    None
3175                }
3176            })
3177            .collect::<Vec<_>>();
3178        if channels.is_empty() {
3179            continue;
3180        }
3181
3182        let update = proto::UpdateChannels {
3183            channels,
3184            ..Default::default()
3185        };
3186
3187        session.peer.send(connection_id, update.clone())?;
3188    }
3189
3190    response.send(Ack {})?;
3191    Ok(())
3192}
3193
3194async fn reorder_channel(
3195    request: proto::ReorderChannel,
3196    response: Response<proto::ReorderChannel>,
3197    session: MessageContext,
3198) -> Result<()> {
3199    let channel_id = ChannelId::from_proto(request.channel_id);
3200    let direction = request.direction();
3201
3202    let updated_channels = session
3203        .db()
3204        .await
3205        .reorder_channel(channel_id, direction, session.user_id())
3206        .await?;
3207
3208    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3209        let connection_pool = session.connection_pool().await;
3210        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3211            let channels = updated_channels
3212                .iter()
3213                .filter_map(|channel| {
3214                    if role.can_see_channel(channel.visibility) {
3215                        Some(channel.to_proto())
3216                    } else {
3217                        None
3218                    }
3219                })
3220                .collect::<Vec<_>>();
3221
3222            if channels.is_empty() {
3223                continue;
3224            }
3225
3226            let update = proto::UpdateChannels {
3227                channels,
3228                ..Default::default()
3229            };
3230
3231            session.peer.send(connection_id, update.clone())?;
3232        }
3233    }
3234
3235    response.send(Ack {})?;
3236    Ok(())
3237}
3238
3239/// Get the list of channel members
3240async fn get_channel_members(
3241    request: proto::GetChannelMembers,
3242    response: Response<proto::GetChannelMembers>,
3243    session: MessageContext,
3244) -> Result<()> {
3245    let db = session.db().await;
3246    let channel_id = ChannelId::from_proto(request.channel_id);
3247    let limit = if request.limit == 0 {
3248        u16::MAX as u64
3249    } else {
3250        request.limit
3251    };
3252    let (members, users) = db
3253        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3254        .await?;
3255    response.send(proto::GetChannelMembersResponse { members, users })?;
3256    Ok(())
3257}
3258
3259/// Accept or decline a channel invitation.
3260async fn respond_to_channel_invite(
3261    request: proto::RespondToChannelInvite,
3262    response: Response<proto::RespondToChannelInvite>,
3263    session: MessageContext,
3264) -> Result<()> {
3265    let db = session.db().await;
3266    let channel_id = ChannelId::from_proto(request.channel_id);
3267    let RespondToChannelInvite {
3268        membership_update,
3269        notifications,
3270    } = db
3271        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3272        .await?;
3273
3274    let mut connection_pool = session.connection_pool().await;
3275    if let Some(membership_update) = membership_update {
3276        notify_membership_updated(
3277            &mut connection_pool,
3278            membership_update,
3279            session.user_id(),
3280            &session.peer,
3281        );
3282    } else {
3283        let update = proto::UpdateChannels {
3284            remove_channel_invitations: vec![channel_id.to_proto()],
3285            ..Default::default()
3286        };
3287
3288        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3289            session.peer.send(connection_id, update.clone())?;
3290        }
3291    };
3292
3293    send_notifications(&connection_pool, &session.peer, notifications);
3294
3295    response.send(proto::Ack {})?;
3296
3297    Ok(())
3298}
3299
3300/// Join the channels' room
3301async fn join_channel(
3302    request: proto::JoinChannel,
3303    response: Response<proto::JoinChannel>,
3304    session: MessageContext,
3305) -> Result<()> {
3306    let channel_id = ChannelId::from_proto(request.channel_id);
3307    join_channel_internal(channel_id, Box::new(response), session).await
3308}
3309
3310trait JoinChannelInternalResponse {
3311    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3312}
3313impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3314    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3315        Response::<proto::JoinChannel>::send(self, result)
3316    }
3317}
3318impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3319    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3320        Response::<proto::JoinRoom>::send(self, result)
3321    }
3322}
3323
3324async fn join_channel_internal(
3325    channel_id: ChannelId,
3326    response: Box<impl JoinChannelInternalResponse>,
3327    session: MessageContext,
3328) -> Result<()> {
3329    let joined_room = {
3330        let mut db = session.db().await;
3331        // If zed quits without leaving the room, and the user re-opens zed before the
3332        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3333        // room they were in.
3334        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3335            tracing::info!(
3336                stale_connection_id = %connection,
3337                "cleaning up stale connection",
3338            );
3339            drop(db);
3340            leave_room_for_session(&session, connection).await?;
3341            db = session.db().await;
3342        }
3343
3344        let (joined_room, membership_updated, role) = db
3345            .join_channel(channel_id, session.user_id(), session.connection_id)
3346            .await?;
3347
3348        let live_kit_connection_info =
3349            session
3350                .app_state
3351                .livekit_client
3352                .as_ref()
3353                .and_then(|live_kit| {
3354                    let (can_publish, token) = if role == ChannelRole::Guest {
3355                        (
3356                            false,
3357                            live_kit
3358                                .guest_token(
3359                                    &joined_room.room.livekit_room,
3360                                    &session.user_id().to_string(),
3361                                )
3362                                .trace_err()?,
3363                        )
3364                    } else {
3365                        (
3366                            true,
3367                            live_kit
3368                                .room_token(
3369                                    &joined_room.room.livekit_room,
3370                                    &session.user_id().to_string(),
3371                                )
3372                                .trace_err()?,
3373                        )
3374                    };
3375
3376                    Some(LiveKitConnectionInfo {
3377                        server_url: live_kit.url().into(),
3378                        token,
3379                        can_publish,
3380                    })
3381                });
3382
3383        response.send(proto::JoinRoomResponse {
3384            room: Some(joined_room.room.clone()),
3385            channel_id: joined_room
3386                .channel
3387                .as_ref()
3388                .map(|channel| channel.id.to_proto()),
3389            live_kit_connection_info,
3390        })?;
3391
3392        let mut connection_pool = session.connection_pool().await;
3393        if let Some(membership_updated) = membership_updated {
3394            notify_membership_updated(
3395                &mut connection_pool,
3396                membership_updated,
3397                session.user_id(),
3398                &session.peer,
3399            );
3400        }
3401
3402        room_updated(&joined_room.room, &session.peer);
3403
3404        joined_room
3405    };
3406
3407    channel_updated(
3408        &joined_room.channel.context("channel not returned")?,
3409        &joined_room.room,
3410        &session.peer,
3411        &*session.connection_pool().await,
3412    );
3413
3414    update_user_contacts(session.user_id(), &session).await?;
3415    Ok(())
3416}
3417
3418/// Start editing the channel notes
3419async fn join_channel_buffer(
3420    request: proto::JoinChannelBuffer,
3421    response: Response<proto::JoinChannelBuffer>,
3422    session: MessageContext,
3423) -> Result<()> {
3424    let db = session.db().await;
3425    let channel_id = ChannelId::from_proto(request.channel_id);
3426
3427    let open_response = db
3428        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3429        .await?;
3430
3431    let collaborators = open_response.collaborators.clone();
3432    response.send(open_response)?;
3433
3434    let update = UpdateChannelBufferCollaborators {
3435        channel_id: channel_id.to_proto(),
3436        collaborators: collaborators.clone(),
3437    };
3438    channel_buffer_updated(
3439        session.connection_id,
3440        collaborators
3441            .iter()
3442            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3443        &update,
3444        &session.peer,
3445    );
3446
3447    Ok(())
3448}
3449
3450/// Edit the channel notes
3451async fn update_channel_buffer(
3452    request: proto::UpdateChannelBuffer,
3453    session: MessageContext,
3454) -> Result<()> {
3455    let db = session.db().await;
3456    let channel_id = ChannelId::from_proto(request.channel_id);
3457
3458    let (collaborators, epoch, version) = db
3459        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3460        .await?;
3461
3462    channel_buffer_updated(
3463        session.connection_id,
3464        collaborators.clone(),
3465        &proto::UpdateChannelBuffer {
3466            channel_id: channel_id.to_proto(),
3467            operations: request.operations,
3468        },
3469        &session.peer,
3470    );
3471
3472    let pool = &*session.connection_pool().await;
3473
3474    let non_collaborators =
3475        pool.channel_connection_ids(channel_id)
3476            .filter_map(|(connection_id, _)| {
3477                if collaborators.contains(&connection_id) {
3478                    None
3479                } else {
3480                    Some(connection_id)
3481                }
3482            });
3483
3484    broadcast(None, non_collaborators, |peer_id| {
3485        session.peer.send(
3486            peer_id,
3487            proto::UpdateChannels {
3488                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3489                    channel_id: channel_id.to_proto(),
3490                    epoch: epoch as u64,
3491                    version: version.clone(),
3492                }],
3493                ..Default::default()
3494            },
3495        )
3496    });
3497
3498    Ok(())
3499}
3500
3501/// Rejoin the channel notes after a connection blip
3502async fn rejoin_channel_buffers(
3503    request: proto::RejoinChannelBuffers,
3504    response: Response<proto::RejoinChannelBuffers>,
3505    session: MessageContext,
3506) -> Result<()> {
3507    let db = session.db().await;
3508    let buffers = db
3509        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3510        .await?;
3511
3512    for rejoined_buffer in &buffers {
3513        let collaborators_to_notify = rejoined_buffer
3514            .buffer
3515            .collaborators
3516            .iter()
3517            .filter_map(|c| Some(c.peer_id?.into()));
3518        channel_buffer_updated(
3519            session.connection_id,
3520            collaborators_to_notify,
3521            &proto::UpdateChannelBufferCollaborators {
3522                channel_id: rejoined_buffer.buffer.channel_id,
3523                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3524            },
3525            &session.peer,
3526        );
3527    }
3528
3529    response.send(proto::RejoinChannelBuffersResponse {
3530        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3531    })?;
3532
3533    Ok(())
3534}
3535
3536/// Stop editing the channel notes
3537async fn leave_channel_buffer(
3538    request: proto::LeaveChannelBuffer,
3539    response: Response<proto::LeaveChannelBuffer>,
3540    session: MessageContext,
3541) -> Result<()> {
3542    let db = session.db().await;
3543    let channel_id = ChannelId::from_proto(request.channel_id);
3544
3545    let left_buffer = db
3546        .leave_channel_buffer(channel_id, session.connection_id)
3547        .await?;
3548
3549    response.send(Ack {})?;
3550
3551    channel_buffer_updated(
3552        session.connection_id,
3553        left_buffer.connections,
3554        &proto::UpdateChannelBufferCollaborators {
3555            channel_id: channel_id.to_proto(),
3556            collaborators: left_buffer.collaborators,
3557        },
3558        &session.peer,
3559    );
3560
3561    Ok(())
3562}
3563
3564fn channel_buffer_updated<T: EnvelopedMessage>(
3565    sender_id: ConnectionId,
3566    collaborators: impl IntoIterator<Item = ConnectionId>,
3567    message: &T,
3568    peer: &Peer,
3569) {
3570    broadcast(Some(sender_id), collaborators, |peer_id| {
3571        peer.send(peer_id, message.clone())
3572    });
3573}
3574
3575fn send_notifications(
3576    connection_pool: &ConnectionPool,
3577    peer: &Peer,
3578    notifications: db::NotificationBatch,
3579) {
3580    for (user_id, notification) in notifications {
3581        for connection_id in connection_pool.user_connection_ids(user_id) {
3582            if let Err(error) = peer.send(
3583                connection_id,
3584                proto::AddNotification {
3585                    notification: Some(notification.clone()),
3586                },
3587            ) {
3588                tracing::error!(
3589                    "failed to send notification to {:?} {}",
3590                    connection_id,
3591                    error
3592                );
3593            }
3594        }
3595    }
3596}
3597
3598/// Send a message to the channel
3599async fn send_channel_message(
3600    request: proto::SendChannelMessage,
3601    response: Response<proto::SendChannelMessage>,
3602    session: MessageContext,
3603) -> Result<()> {
3604    // Validate the message body.
3605    let body = request.body.trim().to_string();
3606    if body.len() > MAX_MESSAGE_LEN {
3607        return Err(anyhow!("message is too long"))?;
3608    }
3609    if body.is_empty() {
3610        return Err(anyhow!("message can't be blank"))?;
3611    }
3612
3613    // TODO: adjust mentions if body is trimmed
3614
3615    let timestamp = OffsetDateTime::now_utc();
3616    let nonce = request.nonce.context("nonce can't be blank")?;
3617
3618    let channel_id = ChannelId::from_proto(request.channel_id);
3619    let CreatedChannelMessage {
3620        message_id,
3621        participant_connection_ids,
3622        notifications,
3623    } = session
3624        .db()
3625        .await
3626        .create_channel_message(
3627            channel_id,
3628            session.user_id(),
3629            &body,
3630            &request.mentions,
3631            timestamp,
3632            nonce.clone().into(),
3633            request.reply_to_message_id.map(MessageId::from_proto),
3634        )
3635        .await?;
3636
3637    let message = proto::ChannelMessage {
3638        sender_id: session.user_id().to_proto(),
3639        id: message_id.to_proto(),
3640        body,
3641        mentions: request.mentions,
3642        timestamp: timestamp.unix_timestamp() as u64,
3643        nonce: Some(nonce),
3644        reply_to_message_id: request.reply_to_message_id,
3645        edited_at: None,
3646    };
3647    broadcast(
3648        Some(session.connection_id),
3649        participant_connection_ids.clone(),
3650        |connection| {
3651            session.peer.send(
3652                connection,
3653                proto::ChannelMessageSent {
3654                    channel_id: channel_id.to_proto(),
3655                    message: Some(message.clone()),
3656                },
3657            )
3658        },
3659    );
3660    response.send(proto::SendChannelMessageResponse {
3661        message: Some(message),
3662    })?;
3663
3664    let pool = &*session.connection_pool().await;
3665    let non_participants =
3666        pool.channel_connection_ids(channel_id)
3667            .filter_map(|(connection_id, _)| {
3668                if participant_connection_ids.contains(&connection_id) {
3669                    None
3670                } else {
3671                    Some(connection_id)
3672                }
3673            });
3674    broadcast(None, non_participants, |peer_id| {
3675        session.peer.send(
3676            peer_id,
3677            proto::UpdateChannels {
3678                latest_channel_message_ids: vec![proto::ChannelMessageId {
3679                    channel_id: channel_id.to_proto(),
3680                    message_id: message_id.to_proto(),
3681                }],
3682                ..Default::default()
3683            },
3684        )
3685    });
3686    send_notifications(pool, &session.peer, notifications);
3687
3688    Ok(())
3689}
3690
3691/// Delete a channel message
3692async fn remove_channel_message(
3693    request: proto::RemoveChannelMessage,
3694    response: Response<proto::RemoveChannelMessage>,
3695    session: MessageContext,
3696) -> Result<()> {
3697    let channel_id = ChannelId::from_proto(request.channel_id);
3698    let message_id = MessageId::from_proto(request.message_id);
3699    let (connection_ids, existing_notification_ids) = session
3700        .db()
3701        .await
3702        .remove_channel_message(channel_id, message_id, session.user_id())
3703        .await?;
3704
3705    broadcast(
3706        Some(session.connection_id),
3707        connection_ids,
3708        move |connection| {
3709            session.peer.send(connection, request.clone())?;
3710
3711            for notification_id in &existing_notification_ids {
3712                session.peer.send(
3713                    connection,
3714                    proto::DeleteNotification {
3715                        notification_id: (*notification_id).to_proto(),
3716                    },
3717                )?;
3718            }
3719
3720            Ok(())
3721        },
3722    );
3723    response.send(proto::Ack {})?;
3724    Ok(())
3725}
3726
3727async fn update_channel_message(
3728    request: proto::UpdateChannelMessage,
3729    response: Response<proto::UpdateChannelMessage>,
3730    session: MessageContext,
3731) -> Result<()> {
3732    let channel_id = ChannelId::from_proto(request.channel_id);
3733    let message_id = MessageId::from_proto(request.message_id);
3734    let updated_at = OffsetDateTime::now_utc();
3735    let UpdatedChannelMessage {
3736        message_id,
3737        participant_connection_ids,
3738        notifications,
3739        reply_to_message_id,
3740        timestamp,
3741        deleted_mention_notification_ids,
3742        updated_mention_notifications,
3743    } = session
3744        .db()
3745        .await
3746        .update_channel_message(
3747            channel_id,
3748            message_id,
3749            session.user_id(),
3750            request.body.as_str(),
3751            &request.mentions,
3752            updated_at,
3753        )
3754        .await?;
3755
3756    let nonce = request.nonce.clone().context("nonce can't be blank")?;
3757
3758    let message = proto::ChannelMessage {
3759        sender_id: session.user_id().to_proto(),
3760        id: message_id.to_proto(),
3761        body: request.body.clone(),
3762        mentions: request.mentions.clone(),
3763        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3764        nonce: Some(nonce),
3765        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3766        edited_at: Some(updated_at.unix_timestamp() as u64),
3767    };
3768
3769    response.send(proto::Ack {})?;
3770
3771    let pool = &*session.connection_pool().await;
3772    broadcast(
3773        Some(session.connection_id),
3774        participant_connection_ids,
3775        |connection| {
3776            session.peer.send(
3777                connection,
3778                proto::ChannelMessageUpdate {
3779                    channel_id: channel_id.to_proto(),
3780                    message: Some(message.clone()),
3781                },
3782            )?;
3783
3784            for notification_id in &deleted_mention_notification_ids {
3785                session.peer.send(
3786                    connection,
3787                    proto::DeleteNotification {
3788                        notification_id: (*notification_id).to_proto(),
3789                    },
3790                )?;
3791            }
3792
3793            for notification in &updated_mention_notifications {
3794                session.peer.send(
3795                    connection,
3796                    proto::UpdateNotification {
3797                        notification: Some(notification.clone()),
3798                    },
3799                )?;
3800            }
3801
3802            Ok(())
3803        },
3804    );
3805
3806    send_notifications(pool, &session.peer, notifications);
3807
3808    Ok(())
3809}
3810
3811/// Mark a channel message as read
3812async fn acknowledge_channel_message(
3813    request: proto::AckChannelMessage,
3814    session: MessageContext,
3815) -> Result<()> {
3816    let channel_id = ChannelId::from_proto(request.channel_id);
3817    let message_id = MessageId::from_proto(request.message_id);
3818    let notifications = session
3819        .db()
3820        .await
3821        .observe_channel_message(channel_id, session.user_id(), message_id)
3822        .await?;
3823    send_notifications(
3824        &*session.connection_pool().await,
3825        &session.peer,
3826        notifications,
3827    );
3828    Ok(())
3829}
3830
3831/// Mark a buffer version as synced
3832async fn acknowledge_buffer_version(
3833    request: proto::AckBufferOperation,
3834    session: MessageContext,
3835) -> Result<()> {
3836    let buffer_id = BufferId::from_proto(request.buffer_id);
3837    session
3838        .db()
3839        .await
3840        .observe_buffer_version(
3841            buffer_id,
3842            session.user_id(),
3843            request.epoch as i32,
3844            &request.version,
3845        )
3846        .await?;
3847    Ok(())
3848}
3849
3850/// Get a Supermaven API key for the user
3851async fn get_supermaven_api_key(
3852    _request: proto::GetSupermavenApiKey,
3853    response: Response<proto::GetSupermavenApiKey>,
3854    session: MessageContext,
3855) -> Result<()> {
3856    let user_id: String = session.user_id().to_string();
3857    if !session.is_staff() {
3858        return Err(anyhow!("supermaven not enabled for this account"))?;
3859    }
3860
3861    let email = session.email().context("user must have an email")?;
3862
3863    let supermaven_admin_api = session
3864        .supermaven_client
3865        .as_ref()
3866        .context("supermaven not configured")?;
3867
3868    let result = supermaven_admin_api
3869        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3870        .await?;
3871
3872    response.send(proto::GetSupermavenApiKeyResponse {
3873        api_key: result.api_key,
3874    })?;
3875
3876    Ok(())
3877}
3878
3879/// Start receiving chat updates for a channel
3880async fn join_channel_chat(
3881    request: proto::JoinChannelChat,
3882    response: Response<proto::JoinChannelChat>,
3883    session: MessageContext,
3884) -> Result<()> {
3885    let channel_id = ChannelId::from_proto(request.channel_id);
3886
3887    let db = session.db().await;
3888    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3889        .await?;
3890    let messages = db
3891        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3892        .await?;
3893    response.send(proto::JoinChannelChatResponse {
3894        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3895        messages,
3896    })?;
3897    Ok(())
3898}
3899
3900/// Stop receiving chat updates for a channel
3901async fn leave_channel_chat(
3902    request: proto::LeaveChannelChat,
3903    session: MessageContext,
3904) -> Result<()> {
3905    let channel_id = ChannelId::from_proto(request.channel_id);
3906    session
3907        .db()
3908        .await
3909        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3910        .await?;
3911    Ok(())
3912}
3913
3914/// Retrieve the chat history for a channel
3915async fn get_channel_messages(
3916    request: proto::GetChannelMessages,
3917    response: Response<proto::GetChannelMessages>,
3918    session: MessageContext,
3919) -> Result<()> {
3920    let channel_id = ChannelId::from_proto(request.channel_id);
3921    let messages = session
3922        .db()
3923        .await
3924        .get_channel_messages(
3925            channel_id,
3926            session.user_id(),
3927            MESSAGE_COUNT_PER_PAGE,
3928            Some(MessageId::from_proto(request.before_message_id)),
3929        )
3930        .await?;
3931    response.send(proto::GetChannelMessagesResponse {
3932        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3933        messages,
3934    })?;
3935    Ok(())
3936}
3937
3938/// Retrieve specific chat messages
3939async fn get_channel_messages_by_id(
3940    request: proto::GetChannelMessagesById,
3941    response: Response<proto::GetChannelMessagesById>,
3942    session: MessageContext,
3943) -> Result<()> {
3944    let message_ids = request
3945        .message_ids
3946        .iter()
3947        .map(|id| MessageId::from_proto(*id))
3948        .collect::<Vec<_>>();
3949    let messages = session
3950        .db()
3951        .await
3952        .get_channel_messages_by_id(session.user_id(), &message_ids)
3953        .await?;
3954    response.send(proto::GetChannelMessagesResponse {
3955        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3956        messages,
3957    })?;
3958    Ok(())
3959}
3960
3961/// Retrieve the current users notifications
3962async fn get_notifications(
3963    request: proto::GetNotifications,
3964    response: Response<proto::GetNotifications>,
3965    session: MessageContext,
3966) -> Result<()> {
3967    let notifications = session
3968        .db()
3969        .await
3970        .get_notifications(
3971            session.user_id(),
3972            NOTIFICATION_COUNT_PER_PAGE,
3973            request.before_id.map(db::NotificationId::from_proto),
3974        )
3975        .await?;
3976    response.send(proto::GetNotificationsResponse {
3977        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3978        notifications,
3979    })?;
3980    Ok(())
3981}
3982
3983/// Mark notifications as read
3984async fn mark_notification_as_read(
3985    request: proto::MarkNotificationRead,
3986    response: Response<proto::MarkNotificationRead>,
3987    session: MessageContext,
3988) -> Result<()> {
3989    let database = &session.db().await;
3990    let notifications = database
3991        .mark_notification_as_read_by_id(
3992            session.user_id(),
3993            NotificationId::from_proto(request.notification_id),
3994        )
3995        .await?;
3996    send_notifications(
3997        &*session.connection_pool().await,
3998        &session.peer,
3999        notifications,
4000    );
4001    response.send(proto::Ack {})?;
4002    Ok(())
4003}
4004
4005fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4006    let message = match message {
4007        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4008        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4009        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4010        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4011        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4012            code: frame.code.into(),
4013            reason: frame.reason.as_str().to_owned().into(),
4014        })),
4015        // We should never receive a frame while reading the message, according
4016        // to the `tungstenite` maintainers:
4017        //
4018        // > It cannot occur when you read messages from the WebSocket, but it
4019        // > can be used when you want to send the raw frames (e.g. you want to
4020        // > send the frames to the WebSocket without composing the full message first).
4021        // >
4022        // > — https://github.com/snapview/tungstenite-rs/issues/268
4023        TungsteniteMessage::Frame(_) => {
4024            bail!("received an unexpected frame while reading the message")
4025        }
4026    };
4027
4028    Ok(message)
4029}
4030
4031fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4032    match message {
4033        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4034        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4035        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4036        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4037        AxumMessage::Close(frame) => {
4038            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4039                code: frame.code.into(),
4040                reason: frame.reason.as_ref().into(),
4041            }))
4042        }
4043    }
4044}
4045
4046fn notify_membership_updated(
4047    connection_pool: &mut ConnectionPool,
4048    result: MembershipUpdated,
4049    user_id: UserId,
4050    peer: &Peer,
4051) {
4052    for membership in &result.new_channels.channel_memberships {
4053        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4054    }
4055    for channel_id in &result.removed_channels {
4056        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4057    }
4058
4059    let user_channels_update = proto::UpdateUserChannels {
4060        channel_memberships: result
4061            .new_channels
4062            .channel_memberships
4063            .iter()
4064            .map(|cm| proto::ChannelMembership {
4065                channel_id: cm.channel_id.to_proto(),
4066                role: cm.role.into(),
4067            })
4068            .collect(),
4069        ..Default::default()
4070    };
4071
4072    let mut update = build_channels_update(result.new_channels);
4073    update.delete_channels = result
4074        .removed_channels
4075        .into_iter()
4076        .map(|id| id.to_proto())
4077        .collect();
4078    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4079
4080    for connection_id in connection_pool.user_connection_ids(user_id) {
4081        peer.send(connection_id, user_channels_update.clone())
4082            .trace_err();
4083        peer.send(connection_id, update.clone()).trace_err();
4084    }
4085}
4086
4087fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4088    proto::UpdateUserChannels {
4089        channel_memberships: channels
4090            .channel_memberships
4091            .iter()
4092            .map(|m| proto::ChannelMembership {
4093                channel_id: m.channel_id.to_proto(),
4094                role: m.role.into(),
4095            })
4096            .collect(),
4097        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4098        observed_channel_message_id: channels.observed_channel_messages.clone(),
4099    }
4100}
4101
4102fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4103    let mut update = proto::UpdateChannels::default();
4104
4105    for channel in channels.channels {
4106        update.channels.push(channel.to_proto());
4107    }
4108
4109    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4110    update.latest_channel_message_ids = channels.latest_channel_messages;
4111
4112    for (channel_id, participants) in channels.channel_participants {
4113        update
4114            .channel_participants
4115            .push(proto::ChannelParticipants {
4116                channel_id: channel_id.to_proto(),
4117                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4118            });
4119    }
4120
4121    for channel in channels.invited_channels {
4122        update.channel_invitations.push(channel.to_proto());
4123    }
4124
4125    update
4126}
4127
4128fn build_initial_contacts_update(
4129    contacts: Vec<db::Contact>,
4130    pool: &ConnectionPool,
4131) -> proto::UpdateContacts {
4132    let mut update = proto::UpdateContacts::default();
4133
4134    for contact in contacts {
4135        match contact {
4136            db::Contact::Accepted { user_id, busy } => {
4137                update.contacts.push(contact_for_user(user_id, busy, pool));
4138            }
4139            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4140            db::Contact::Incoming { user_id } => {
4141                update
4142                    .incoming_requests
4143                    .push(proto::IncomingContactRequest {
4144                        requester_id: user_id.to_proto(),
4145                    })
4146            }
4147        }
4148    }
4149
4150    update
4151}
4152
4153fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4154    proto::Contact {
4155        user_id: user_id.to_proto(),
4156        online: pool.is_user_online(user_id),
4157        busy,
4158    }
4159}
4160
4161fn room_updated(room: &proto::Room, peer: &Peer) {
4162    broadcast(
4163        None,
4164        room.participants
4165            .iter()
4166            .filter_map(|participant| Some(participant.peer_id?.into())),
4167        |peer_id| {
4168            peer.send(
4169                peer_id,
4170                proto::RoomUpdated {
4171                    room: Some(room.clone()),
4172                },
4173            )
4174        },
4175    );
4176}
4177
4178fn channel_updated(
4179    channel: &db::channel::Model,
4180    room: &proto::Room,
4181    peer: &Peer,
4182    pool: &ConnectionPool,
4183) {
4184    let participants = room
4185        .participants
4186        .iter()
4187        .map(|p| p.user_id)
4188        .collect::<Vec<_>>();
4189
4190    broadcast(
4191        None,
4192        pool.channel_connection_ids(channel.root_id())
4193            .filter_map(|(channel_id, role)| {
4194                role.can_see_channel(channel.visibility)
4195                    .then_some(channel_id)
4196            }),
4197        |peer_id| {
4198            peer.send(
4199                peer_id,
4200                proto::UpdateChannels {
4201                    channel_participants: vec![proto::ChannelParticipants {
4202                        channel_id: channel.id.to_proto(),
4203                        participant_user_ids: participants.clone(),
4204                    }],
4205                    ..Default::default()
4206                },
4207            )
4208        },
4209    );
4210}
4211
4212async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4213    let db = session.db().await;
4214
4215    let contacts = db.get_contacts(user_id).await?;
4216    let busy = db.is_user_busy(user_id).await?;
4217
4218    let pool = session.connection_pool().await;
4219    let updated_contact = contact_for_user(user_id, busy, &pool);
4220    for contact in contacts {
4221        if let db::Contact::Accepted {
4222            user_id: contact_user_id,
4223            ..
4224        } = contact
4225        {
4226            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4227                session
4228                    .peer
4229                    .send(
4230                        contact_conn_id,
4231                        proto::UpdateContacts {
4232                            contacts: vec![updated_contact.clone()],
4233                            remove_contacts: Default::default(),
4234                            incoming_requests: Default::default(),
4235                            remove_incoming_requests: Default::default(),
4236                            outgoing_requests: Default::default(),
4237                            remove_outgoing_requests: Default::default(),
4238                        },
4239                    )
4240                    .trace_err();
4241            }
4242        }
4243    }
4244    Ok(())
4245}
4246
/// Removes `connection_id` from the room it is in and notifies everyone affected.
///
/// If the database reports that the connection was not in a room (`leave_room`
/// returned `None`), this returns `Ok(())` without doing anything else.
/// Otherwise it does the following, in order:
/// 1. notifies the connections of each project in `left_room.left_projects`
///    (see `project_left`);
/// 2. broadcasts the updated room to its participants and, if the room
///    belongs to a channel, sends the updated participant list to that
///    channel's viewers;
/// 3. sends `proto::CallCanceled` to every connection of each user in
///    `canceled_calls_to_user_ids`;
/// 4. pushes refreshed contact entries for the leaving user and for those
///    canceled users;
/// 5. if a LiveKit client is configured, removes this user from the LiveKit
///    room, and deletes that LiveKit room when `left_room.deleted` is set.
///
/// # Errors
/// Propagates errors from `leave_room` and `update_user_contacts`. LiveKit
/// calls and message sends only log their failures via `trace_err`.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared here and assigned inside the `if let` below so the values outlive it.
    // The `else` branch returns early, so every use after the `if let` sees them
    // initialized.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // NOTE(review): if `session.db()` returns a lock guard, this scrutinee temporary
    // stays alive for the whole `if let` body, i.e. through `project_left` and
    // `room_updated`. Keep that in mind before restructuring.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out below. As a result, the room broadcast by
        // `room_updated` carries the default (empty) value in `livekit_room`.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so this connection-pool guard is released before `update_user_contacts`,
    // which acquires the connection pool again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // NOTE(review): a failure here returns early via `?`. That skips the remaining
    // contact updates and the LiveKit cleanup below. Confirm this is intended.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4320
4321async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4322    let left_channel_buffers = session
4323        .db()
4324        .await
4325        .leave_channel_buffers(session.connection_id)
4326        .await?;
4327
4328    for left_buffer in left_channel_buffers {
4329        channel_buffer_updated(
4330            session.connection_id,
4331            left_buffer.connections,
4332            &proto::UpdateChannelBufferCollaborators {
4333                channel_id: left_buffer.channel_id.to_proto(),
4334                collaborators: left_buffer.collaborators,
4335            },
4336            &session.peer,
4337        );
4338    }
4339
4340    Ok(())
4341}
4342
4343fn project_left(project: &db::LeftProject, session: &Session) {
4344    for connection_id in &project.connection_ids {
4345        if project.should_unshare {
4346            session
4347                .peer
4348                .send(
4349                    *connection_id,
4350                    proto::UnshareProject {
4351                        project_id: project.id.to_proto(),
4352                    },
4353                )
4354                .trace_err();
4355        } else {
4356            session
4357                .peer
4358                .send(
4359                    *connection_id,
4360                    proto::RemoveProjectCollaborator {
4361                        project_id: project.id.to_proto(),
4362                        peer_id: Some(session.connection_id.into()),
4363                    },
4364                )
4365                .trace_err();
4366        }
4367    }
4368}
4369
4370pub trait ResultExt {
4371    type Ok;
4372
4373    fn trace_err(self) -> Option<Self::Ok>;
4374}
4375
4376impl<T, E> ResultExt for Result<T, E>
4377where
4378    E: std::fmt::Debug,
4379{
4380    type Ok = T;
4381
4382    #[track_caller]
4383    fn trace_err(self) -> Option<T> {
4384        match self {
4385            Ok(value) => Some(value),
4386            Err(error) => {
4387                tracing::error!("{:?}", error);
4388                None
4389            }
4390        }
4391    }
4392}