rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::{
   5    AppState, Error, Result, auth,
   6    db::{
   7        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   8        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
   9        NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
  10        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  11    },
  12    executor::Executor,
  13};
  14use anyhow::{Context as _, anyhow, bail};
  15use async_tungstenite::tungstenite::{
  16    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  17};
  18use axum::headers::UserAgent;
  19use axum::{
  20    Extension, Router, TypedHeader,
  21    body::Body,
  22    extract::{
  23        ConnectInfo, WebSocketUpgrade,
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use futures::TryFutureExt as _;
  36use reqwest_client::ReqwestClient;
  37use rpc::proto::{MultiLspQuery, split_repository_update};
  38use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  39use tracing::Span;
  40
  41use futures::{
  42    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  43    stream::FuturesUnordered,
  44};
  45use prometheus::{IntGauge, register_int_gauge};
  46use rpc::{
  47    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  50        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52};
  53use semantic_version::SemanticVersion;
  54use serde::{Serialize, Serializer};
  55use std::{
  56    any::TypeId,
  57    future::Future,
  58    marker::PhantomData,
  59    mem,
  60    net::SocketAddr,
  61    ops::{Deref, DerefMut},
  62    rc::Rc,
  63    sync::{
  64        Arc, OnceLock,
  65        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  66    },
  67    time::{Duration, Instant},
  68};
  69use time::OffsetDateTime;
  70use tokio::sync::{Semaphore, watch};
  71use tower::ServiceBuilder;
  72use tracing::{
  73    Instrument,
  74    field::{self},
  75    info_span, instrument,
  76};
  77
/// Grace period associated with dropped connections.
///
/// NOTE(review): not used in the visible part of this file; presumably the time a disconnected
/// client has to reconnect before its state is released — confirm against its call sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long the cleanup task spawned by `Server::start` sleeps before removing resources left
/// behind by previous servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean
/// up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the next three limits are used outside this chunk; presumably the page size for
// channel-message history, the maximum channel-message length, and the page size for
// notification listings — confirm.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Maximum number of `ConnectionGuard`s that `ConnectionGuard::try_acquire` will hand out at once.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Number of currently held `ConnectionGuard`s (incremented on acquire, decremented on drop).
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Names of the span fields that message timings are recorded under. Values are milliseconds as
// `f64` (computed from microseconds / 1000).
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";
  94
/// Type-erased handler for one message type, stored in `Server::handlers` keyed by the message's
/// `TypeId`.
///
/// It is called with the received envelope, the sending connection's `Session` and the tracing
/// span for the message, and returns a boxed future that performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
  97
  98pub struct ConnectionGuard;
  99
 100impl ConnectionGuard {
 101    pub fn try_acquire() -> Result<Self, ()> {
 102        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
 103        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 104            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 105            tracing::error!(
 106                "too many concurrent connections: {}",
 107                current_connections + 1
 108            );
 109            return Err(());
 110        }
 111        Ok(ConnectionGuard)
 112    }
 113}
 114
 115impl Drop for ConnectionGuard {
 116    fn drop(&mut self) {
 117        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 118    }
 119}
 120
 121struct Response<R> {
 122    peer: Arc<Peer>,
 123    receipt: Receipt<R>,
 124    responded: Arc<AtomicBool>,
 125}
 126
 127impl<R: RequestMessage> Response<R> {
 128    fn send(self, payload: R::Response) -> Result<()> {
 129        self.responded.store(true, SeqCst);
 130        self.peer.respond(self.receipt, payload)?;
 131        Ok(())
 132    }
 133}
 134
 135#[derive(Clone, Debug)]
 136pub enum Principal {
 137    User(User),
 138    Impersonated { user: User, admin: User },
 139}
 140
 141impl Principal {
 142    fn update_span(&self, span: &tracing::Span) {
 143        match &self {
 144            Principal::User(user) => {
 145                span.record("user_id", user.id.0);
 146                span.record("login", &user.github_login);
 147            }
 148            Principal::Impersonated { user, admin } => {
 149                span.record("user_id", user.id.0);
 150                span.record("login", &user.github_login);
 151                span.record("impersonator", &admin.github_login);
 152            }
 153        }
 154    }
 155}
 156
 157#[derive(Clone)]
 158struct MessageContext {
 159    session: Session,
 160    span: tracing::Span,
 161}
 162
 163impl Deref for MessageContext {
 164    type Target = Session;
 165
 166    fn deref(&self) -> &Self::Target {
 167        &self.session
 168    }
 169}
 170
 171impl MessageContext {
 172    pub fn forward_request<T: RequestMessage>(
 173        &self,
 174        receiver_id: ConnectionId,
 175        request: T,
 176    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 177        let request_start_time = Instant::now();
 178        let span = self.span.clone();
 179        tracing::info!("start forwarding request");
 180        self.peer
 181            .forward_request(self.connection_id, receiver_id, request)
 182            .inspect(move |_| {
 183                span.record(
 184                    HOST_WAITING_MS,
 185                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 186                );
 187            })
 188            .inspect_err(|_| tracing::error!("error forwarding request"))
 189            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 190    }
 191}
 192
 193#[derive(Clone)]
 194struct Session {
 195    principal: Principal,
 196    connection_id: ConnectionId,
 197    db: Arc<tokio::sync::Mutex<DbHandle>>,
 198    peer: Arc<Peer>,
 199    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 200    app_state: Arc<AppState>,
 201    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 202    /// The GeoIP country code for the user.
 203    #[allow(unused)]
 204    geoip_country_code: Option<String>,
 205    #[allow(unused)]
 206    system_id: Option<String>,
 207    _executor: Executor,
 208}
 209
 210impl Session {
 211    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 212        #[cfg(test)]
 213        tokio::task::yield_now().await;
 214        let guard = self.db.lock().await;
 215        #[cfg(test)]
 216        tokio::task::yield_now().await;
 217        guard
 218    }
 219
 220    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 221        #[cfg(test)]
 222        tokio::task::yield_now().await;
 223        let guard = self.connection_pool.lock();
 224        ConnectionPoolGuard {
 225            guard,
 226            _not_send: PhantomData,
 227        }
 228    }
 229
 230    fn is_staff(&self) -> bool {
 231        match &self.principal {
 232            Principal::User(user) => user.admin,
 233            Principal::Impersonated { .. } => true,
 234        }
 235    }
 236
 237    fn user_id(&self) -> UserId {
 238        match &self.principal {
 239            Principal::User(user) => user.id,
 240            Principal::Impersonated { user, .. } => user.id,
 241        }
 242    }
 243
 244    pub fn email(&self) -> Option<String> {
 245        match &self.principal {
 246            Principal::User(user) => user.email_address.clone(),
 247            Principal::Impersonated { user, .. } => user.email_address.clone(),
 248        }
 249    }
 250}
 251
 252impl Debug for Session {
 253    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 254        let mut result = f.debug_struct("Session");
 255        match &self.principal {
 256            Principal::User(user) => {
 257                result.field("user", &user.github_login);
 258            }
 259            Principal::Impersonated { user, admin } => {
 260                result.field("user", &user.github_login);
 261                result.field("impersonator", &admin.github_login);
 262            }
 263        }
 264        result.field("connection_id", &self.connection_id).finish()
 265    }
 266}
 267
 268struct DbHandle(Arc<Database>);
 269
 270impl Deref for DbHandle {
 271    type Target = Database;
 272
 273    fn deref(&self) -> &Self::Target {
 274        self.0.as_ref()
 275    }
 276}
 277
/// The collaboration RPC server: owns the peer, the connection pool and the registered
/// per-message-type handlers.
pub struct Server {
    /// This server's id; behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag: `teardown` sends `true` on it and the test-only `reset` sends `false`.
    teardown: watch::Sender<bool>,
}
 286
/// Lock guard for the shared `ConnectionPool`, as returned by `Session::connection_pool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so any future that keeps it alive
/// across an `.await` cannot itself be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 291
/// Serializable view of the server's peer and connection pool. The pool stays locked for as long
/// as the snapshot (which owns the guard) is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the `ConnectionPool` the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 298
 299pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 300where
 301    S: Serializer,
 302    T: Deref<Target = U>,
 303    U: Serialize,
 304{
 305    Serialize::serialize(value.deref(), serializer)
 306}
 307
 308impl Server {
 309    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 310        let mut server = Self {
 311            id: parking_lot::Mutex::new(id),
 312            peer: Peer::new(id.0 as u32),
 313            app_state: app_state.clone(),
 314            connection_pool: Default::default(),
 315            handlers: Default::default(),
 316            teardown: watch::channel(false).0,
 317        };
 318
 319        server
 320            .add_request_handler(ping)
 321            .add_request_handler(create_room)
 322            .add_request_handler(join_room)
 323            .add_request_handler(rejoin_room)
 324            .add_request_handler(leave_room)
 325            .add_request_handler(set_room_participant_role)
 326            .add_request_handler(call)
 327            .add_request_handler(cancel_call)
 328            .add_message_handler(decline_call)
 329            .add_request_handler(update_participant_location)
 330            .add_request_handler(share_project)
 331            .add_message_handler(unshare_project)
 332            .add_request_handler(join_project)
 333            .add_message_handler(leave_project)
 334            .add_request_handler(update_project)
 335            .add_request_handler(update_worktree)
 336            .add_request_handler(update_repository)
 337            .add_request_handler(remove_repository)
 338            .add_message_handler(start_language_server)
 339            .add_message_handler(update_language_server)
 340            .add_message_handler(update_diagnostic_summary)
 341            .add_message_handler(update_worktree_settings)
 342            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 343            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 344            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 345            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 346            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 347            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 348            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 349            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 350            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 351            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 352            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 353            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 354            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 355            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 356            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 357            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 358            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 359            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 360            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 361            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 362            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 363            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 364            .add_request_handler(
 365                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 366            )
 367            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 368            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 369            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 370            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 371            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
 372            .add_request_handler(
 373                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 374            )
 375            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 376            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 377            .add_request_handler(
 378                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 379            )
 380            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 381            .add_request_handler(
 382                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 383            )
 384            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 385            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 386            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 387            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 388            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 389            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 390            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 391            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 392            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 393            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 394            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 395            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 396            .add_request_handler(
 397                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 398            )
 399            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 400            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 401            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 402            .add_request_handler(multi_lsp_query)
 403            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 404            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 405            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 406            .add_message_handler(create_buffer_for_peer)
 407            .add_request_handler(update_buffer)
 408            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 409            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 410            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 411            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 412            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 413            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 414            .add_message_handler(
 415                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 416            )
 417            .add_request_handler(get_users)
 418            .add_request_handler(fuzzy_search_users)
 419            .add_request_handler(request_contact)
 420            .add_request_handler(remove_contact)
 421            .add_request_handler(respond_to_contact_request)
 422            .add_message_handler(subscribe_to_channels)
 423            .add_request_handler(create_channel)
 424            .add_request_handler(delete_channel)
 425            .add_request_handler(invite_channel_member)
 426            .add_request_handler(remove_channel_member)
 427            .add_request_handler(set_channel_member_role)
 428            .add_request_handler(set_channel_visibility)
 429            .add_request_handler(rename_channel)
 430            .add_request_handler(join_channel_buffer)
 431            .add_request_handler(leave_channel_buffer)
 432            .add_message_handler(update_channel_buffer)
 433            .add_request_handler(rejoin_channel_buffers)
 434            .add_request_handler(get_channel_members)
 435            .add_request_handler(respond_to_channel_invite)
 436            .add_request_handler(join_channel)
 437            .add_request_handler(join_channel_chat)
 438            .add_message_handler(leave_channel_chat)
 439            .add_request_handler(send_channel_message)
 440            .add_request_handler(remove_channel_message)
 441            .add_request_handler(update_channel_message)
 442            .add_request_handler(get_channel_messages)
 443            .add_request_handler(get_channel_messages_by_id)
 444            .add_request_handler(get_notifications)
 445            .add_request_handler(mark_notification_as_read)
 446            .add_request_handler(move_channel)
 447            .add_request_handler(reorder_channel)
 448            .add_request_handler(follow)
 449            .add_message_handler(unfollow)
 450            .add_message_handler(update_followers)
 451            .add_message_handler(acknowledge_channel_message)
 452            .add_message_handler(acknowledge_buffer_version)
 453            .add_request_handler(get_supermaven_api_key)
 454            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 455            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 456            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 457            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 458            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 459            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 460            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 461            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 462            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 463            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 464            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 465            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 466            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 467            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 468            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 469            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 470            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 471            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 472            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 473            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 474            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 475            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 476            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 477            .add_message_handler(update_context);
 478
 479        Arc::new(server)
 480    }
 481
    /// Schedules cleanup of database state left over from earlier server runs, then returns.
    ///
    /// A detached task, instrumented with a `start server` span, sleeps for `CLEANUP_TIMEOUT`.
    /// It then performs the following steps, passing this server's id (and, where needed, the
    /// configured `zed_environment`) to the database:
    /// 1. Deletes stale channel chat participants.
    /// 2. Fetches the stale room and channel-buffer ids.
    ///    - For each channel buffer, it clears stale collaborators and sends the refreshed
    ///      collaborator list to the buffer's remaining connections.
    ///    - For each room, it clears stale participants and broadcasts the room (and channel)
    ///      update. It also sends `CallCanceled` to users whose calls were canceled, and pushes
    ///      updated contact entries to the affected users' accepted contacts.
    ///    - Finally, it deletes the LiveKit room when a LiveKit client is configured and no
    ///      participants remain.
    /// 3. Deletes stale channel chat participants again, clears old worktree entries and deletes
    ///    stale servers.
    ///
    /// Each fallible step goes through `trace_err` (which yields `None` on failure) and is
    /// otherwise ignored, so one failing step does not stop the rest.
    ///
    /// NOTE(review): the returned `Result` is always `Ok(())`. Failures inside the background
    /// task are never reported to the caller.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                // NOTE(review): the same call is repeated after the room cleanup below; confirm
                // whether both are needed (e.g. to catch participants that became stale meanwhile).
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale channel-buffer collaborators and tell the remaining connections
                    // about the refreshed collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used when clearing the room fails: nothing to notify, nothing
                        // to delete.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and users whose calls were canceled
                            // need their contact entry refreshed.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool guard is gone before the `.await`s below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Recompute each affected user's contact entry (including busy status)
                        // and push it to every connection of their accepted contacts. The pool
                        // is only locked after both database queries have completed.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room once nobody is left in it.
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                    }
                }

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 651
 652    pub fn teardown(&self) {
 653        self.peer.teardown();
 654        self.connection_pool.lock().reset();
 655        let _ = self.teardown.send(true);
 656    }
 657
 658    #[cfg(test)]
 659    pub fn reset(&self, id: ServerId) {
 660        self.teardown();
 661        *self.id.lock() = id;
 662        self.peer.reset(id.0 as u32);
 663        let _ = self.teardown.send(false);
 664    }
 665
 666    #[cfg(test)]
 667    pub fn id(&self) -> ServerId {
 668        *self.id.lock()
 669    }
 670
    /// Registers `handler` as the processor for messages whose payload type is `M`.
    ///
    /// The handler is type-erased into a boxed closure stored under
    /// `TypeId::of::<M>()`. When invoked, the closure:
    /// - downcasts the envelope back to `TypedEnvelope<M>`;
    /// - runs the handler with a [`MessageContext`];
    /// - records total, processing and queue durations on the message span;
    /// - logs whether handling succeeded.
    ///
    /// # Panics
    ///
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // Handlers are looked up by the payload's `TypeId`, which is the key used
                // above, so this downcast is expected to succeed.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // total      = time since the envelope was received
                    // processing = time since the handler was invoked
                    // queue      = the difference, i.e. time spent waiting before invocation
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
 714
 715    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 716    where
 717        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 718        Fut: 'static + Send + Future<Output = Result<()>>,
 719        M: EnvelopedMessage,
 720    {
 721        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 722        self
 723    }
 724
    /// Registers `handler` as the processor for request messages of type `M`.
    ///
    /// The handler receives the payload, a [`Response`] through which it is
    /// expected to reply, and the per-message context. After it finishes:
    /// - `Ok(())` after a response was sent is success.
    /// - `Ok(())` without a response becomes the error
    ///   "handler did not send a response".
    /// - `Err(error)` is converted to a protocol error and sent back to the
    ///   requester. The error is then returned so `add_handler` logs it.
    ///   `Error::Internal` keeps its own proto form; any other error is wrapped in
    ///   `ErrorCode::Internal` with its display text.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        // Shared so each per-message future can own its own handle to the handler.
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // Shared with the `Response`; presumably set once a reply is sent, so a
                // handler that never responds can be detected below.
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // NOTE(review): this error reply is sent even if the handler had
                        // already responded before failing — confirm peers tolerate that.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
 763
    /// Builds the future that drives one client connection from handshake to sign-out.
    ///
    /// When polled, the returned future does the following:
    /// 1. Returns immediately if the teardown flag is already set.
    /// 2. Registers the connection with the peer and builds its [`Session`],
    ///    including an optional Supermaven admin client.
    /// 3. Sends the initial client update (hello, contacts, channels, pending call).
    ///    `connection_guard` is held until this step finishes and then dropped.
    /// 4. Dispatches incoming messages to the registered handlers until the
    ///    connection closes, I/O fails, or teardown is signalled. Background
    ///    messages run on detached tasks; foreground messages run in a local
    ///    `FuturesUnordered`. Both are bounded by a shared semaphore.
    /// 5. Runs [`connection_lost`] for cleanup. This step is skipped on teardown,
    ///    which returns directly.
    ///
    /// The connection's span records the address, the principal, and (when
    /// known) the user agent, release channel and GeoIP country code.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribed eagerly, outside the returned future, so the teardown flag can be
        // checked as soon as the future first runs and watched for the whole connection.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            // `handle_io` is polled in the select loop below (its completion ends the
            // connection); `incoming_rx` yields the messages to dispatch.
            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    // NOTE(review): this return skips `connection_lost`, so the connection
                    // registered with the peer above is not disconnected here — confirm.
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                // NOTE(review): this return also skips `connection_lost`; if the failure
                // happened after the connection was added to the pool, it is not removed
                // here — confirm it is cleaned up elsewhere.
                return;
            }
            // The guard only covers connection setup; it is not held for the lifetime
            // of the connection.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            //
            // A semaphore permit is acquired before each message is read and released when its
            // handler (foreground or background) finishes. At most MAX_CONCURRENT_HANDLERS
            // handlers are therefore in flight, and reading pauses while all permits are taken.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // Biased polling order: teardown (returns without sign-out), I/O completion
                // (ends the loop), progress on in-flight foreground handlers, then the next
                // incoming message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                multi_lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set drops any foreground handler futures that have not
            // finished. Detached background handlers are unaffected.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
 955
 956    async fn send_initial_client_update(
 957        &self,
 958        connection_id: ConnectionId,
 959        zed_version: ZedVersion,
 960        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 961        session: &Session,
 962    ) -> Result<()> {
 963        self.peer.send(
 964            connection_id,
 965            proto::Hello {
 966                peer_id: Some(connection_id.into()),
 967            },
 968        )?;
 969        tracing::info!("sent hello message");
 970        if let Some(send_connection_id) = send_connection_id.take() {
 971            let _ = send_connection_id.send(connection_id);
 972        }
 973
 974        match &session.principal {
 975            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 976                if !user.connected_once {
 977                    self.peer.send(connection_id, proto::ShowContacts {})?;
 978                    self.app_state
 979                        .db
 980                        .set_user_connected_once(user.id, true)
 981                        .await?;
 982                }
 983
 984                let contacts = self.app_state.db.get_contacts(user.id).await?;
 985
 986                {
 987                    let mut pool = self.connection_pool.lock();
 988                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 989                    self.peer.send(
 990                        connection_id,
 991                        build_initial_contacts_update(contacts, &pool),
 992                    )?;
 993                }
 994
 995                if should_auto_subscribe_to_channels(zed_version) {
 996                    subscribe_user_to_channels(user.id, session).await?;
 997                }
 998
 999                if let Some(incoming_call) =
1000                    self.app_state.db.incoming_call_for_user(user.id).await?
1001                {
1002                    self.peer.send(connection_id, incoming_call)?;
1003                }
1004
1005                update_user_contacts(user.id, session).await?;
1006            }
1007        }
1008
1009        Ok(())
1010    }
1011
1012    pub async fn invite_code_redeemed(
1013        self: &Arc<Self>,
1014        inviter_id: UserId,
1015        invitee_id: UserId,
1016    ) -> Result<()> {
1017        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await?
1018            && let Some(code) = &user.invite_code {
1019                let pool = self.connection_pool.lock();
1020                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1021                for connection_id in pool.user_connection_ids(inviter_id) {
1022                    self.peer.send(
1023                        connection_id,
1024                        proto::UpdateContacts {
1025                            contacts: vec![invitee_contact.clone()],
1026                            ..Default::default()
1027                        },
1028                    )?;
1029                    self.peer.send(
1030                        connection_id,
1031                        proto::UpdateInviteInfo {
1032                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1033                            count: user.invite_count as u32,
1034                        },
1035                    )?;
1036                }
1037            }
1038        Ok(())
1039    }
1040
1041    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1042        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await?
1043            && let Some(invite_code) = &user.invite_code {
1044                let pool = self.connection_pool.lock();
1045                for connection_id in pool.user_connection_ids(user_id) {
1046                    self.peer.send(
1047                        connection_id,
1048                        proto::UpdateInviteInfo {
1049                            url: format!(
1050                                "{}{}",
1051                                self.app_state.config.invite_link_prefix, invite_code
1052                            ),
1053                            count: user.invite_count as u32,
1054                        },
1055                    )?;
1056                }
1057            }
1058        Ok(())
1059    }
1060
1061    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1062        ServerSnapshot {
1063            connection_pool: ConnectionPoolGuard {
1064                guard: self.connection_pool.lock(),
1065                _not_send: PhantomData,
1066            },
1067            peer: &self.peer,
1068        }
1069    }
1070}
1071
/// Lets a locked `ConnectionPoolGuard` be used wherever a `&ConnectionPool` is
/// expected, by dereferencing through the inner lock guard.
impl Deref for ConnectionPoolGuard<'_> {
    type Target = ConnectionPool;

    fn deref(&self) -> &Self::Target {
        &self.guard
    }
}
1079
/// Mutable counterpart of the `Deref` impl: exposes the locked pool as
/// `&mut ConnectionPool`.
impl DerefMut for ConnectionPoolGuard<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard
    }
}
1085
/// In test builds, checks the connection pool's invariants when the guard is
/// dropped. This runs before the inner lock guard field is released. Non-test
/// builds do nothing extra here.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // Statement-level `#[cfg(test)]` rather than `cfg!(test)`: the call is
        // compiled out entirely in non-test builds.
        #[cfg(test)]
        self.check_invariants();
    }
}
1092
1093fn broadcast<F>(
1094    sender_id: Option<ConnectionId>,
1095    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1096    mut f: F,
1097) where
1098    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1099{
1100    for receiver_id in receiver_ids {
1101        if Some(receiver_id) != sender_id
1102            && let Err(error) = f(receiver_id) {
1103                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1104            }
1105    }
1106}
1107
1108pub struct ProtocolVersion(u32);
1109
1110impl Header for ProtocolVersion {
1111    fn name() -> &'static HeaderName {
1112        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1113        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1114    }
1115
1116    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1117    where
1118        Self: Sized,
1119        I: Iterator<Item = &'i axum::http::HeaderValue>,
1120    {
1121        let version = values
1122            .next()
1123            .ok_or_else(axum::headers::Error::invalid)?
1124            .to_str()
1125            .map_err(|_| axum::headers::Error::invalid())?
1126            .parse()
1127            .map_err(|_| axum::headers::Error::invalid())?;
1128        Ok(Self(version))
1129    }
1130
1131    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1132        values.extend([self.0.to_string().parse().unwrap()]);
1133    }
1134}
1135
1136pub struct AppVersionHeader(SemanticVersion);
1137impl Header for AppVersionHeader {
1138    fn name() -> &'static HeaderName {
1139        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1140        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1141    }
1142
1143    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1144    where
1145        Self: Sized,
1146        I: Iterator<Item = &'i axum::http::HeaderValue>,
1147    {
1148        let version = values
1149            .next()
1150            .ok_or_else(axum::headers::Error::invalid)?
1151            .to_str()
1152            .map_err(|_| axum::headers::Error::invalid())?
1153            .parse()
1154            .map_err(|_| axum::headers::Error::invalid())?;
1155        Ok(Self(version))
1156    }
1157
1158    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1159        values.extend([self.0.to_string().parse().unwrap()]);
1160    }
1161}
1162
1163#[derive(Debug)]
1164pub struct ReleaseChannelHeader(String);
1165
1166impl Header for ReleaseChannelHeader {
1167    fn name() -> &'static HeaderName {
1168        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1169        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1170    }
1171
1172    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1173    where
1174        Self: Sized,
1175        I: Iterator<Item = &'i axum::http::HeaderValue>,
1176    {
1177        Ok(Self(
1178            values
1179                .next()
1180                .ok_or_else(axum::headers::Error::invalid)?
1181                .to_str()
1182                .map_err(|_| axum::headers::Error::invalid())?
1183                .to_owned(),
1184        ))
1185    }
1186
1187    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1188        values.extend([self.0.parse().unwrap()]);
1189    }
1190}
1191
1192pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1193    Router::new()
1194        .route("/rpc", get(handle_websocket_request))
1195        .layer(
1196            ServiceBuilder::new()
1197                .layer(Extension(server.app_state.clone()))
1198                .layer(middleware::from_fn(auth::validate_header)),
1199        )
1200        .route("/metrics", get(handle_metrics))
1201        .layer(Extension(server))
1202}
1203
1204pub async fn handle_websocket_request(
1205    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1206    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1207    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1208    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1209    Extension(server): Extension<Arc<Server>>,
1210    Extension(principal): Extension<Principal>,
1211    user_agent: Option<TypedHeader<UserAgent>>,
1212    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1213    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1214    ws: WebSocketUpgrade,
1215) -> axum::response::Response {
1216    if protocol_version != rpc::PROTOCOL_VERSION {
1217        return (
1218            StatusCode::UPGRADE_REQUIRED,
1219            "client must be upgraded".to_string(),
1220        )
1221            .into_response();
1222    }
1223
1224    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1225        return (
1226            StatusCode::UPGRADE_REQUIRED,
1227            "no version header found".to_string(),
1228        )
1229            .into_response();
1230    };
1231
1232    let release_channel = release_channel_header.map(|header| header.0.0);
1233
1234    if !version.can_collaborate() {
1235        return (
1236            StatusCode::UPGRADE_REQUIRED,
1237            "client must be upgraded".to_string(),
1238        )
1239            .into_response();
1240    }
1241
1242    let socket_address = socket_address.to_string();
1243
1244    // Acquire connection guard before WebSocket upgrade
1245    let connection_guard = match ConnectionGuard::try_acquire() {
1246        Ok(guard) => guard,
1247        Err(()) => {
1248            return (
1249                StatusCode::SERVICE_UNAVAILABLE,
1250                "Too many concurrent connections",
1251            )
1252                .into_response();
1253        }
1254    };
1255
1256    ws.on_upgrade(move |socket| {
1257        let socket = socket
1258            .map_ok(to_tungstenite_message)
1259            .err_into()
1260            .with(|message| async move { to_axum_message(message) });
1261        let connection = Connection::new(Box::pin(socket));
1262        async move {
1263            server
1264                .handle_connection(
1265                    connection,
1266                    socket_address,
1267                    principal,
1268                    version,
1269                    release_channel,
1270                    user_agent.map(|header| header.to_string()),
1271                    country_code_header.map(|header| header.to_string()),
1272                    system_id_header.map(|header| header.to_string()),
1273                    None,
1274                    Executor::Production,
1275                    Some(connection_guard),
1276                )
1277                .await;
1278        }
1279    })
1280}
1281
1282pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1283    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1284    let connections_metric = CONNECTIONS_METRIC
1285        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1286
1287    let connections = server
1288        .connection_pool
1289        .lock()
1290        .connections()
1291        .filter(|connection| !connection.admin)
1292        .count();
1293    connections_metric.set(connections as _);
1294
1295    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1296    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1297        register_int_gauge!(
1298            "shared_projects",
1299            "number of open projects with one or more guests"
1300        )
1301        .unwrap()
1302    });
1303
1304    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1305    shared_projects_metric.set(shared_projects as _);
1306
1307    let encoder = prometheus::TextEncoder::new();
1308    let metric_families = prometheus::gather();
1309    let encoded_metrics = encoder
1310        .encode_to_string(&metric_families)
1311        .map_err(|err| anyhow!("{err}"))?;
1312    Ok(encoded_metrics)
1313}
1314
/// Cleans up after a connection has ended.
///
/// Immediate steps: disconnect the connection from the peer, remove it from the
/// connection pool, and tell the database the connection was lost.
///
/// It then races `RECONNECT_TIMEOUT` against the teardown channel:
/// - If the teardown channel changes first, nothing more is done.
/// - If the timeout elapses (this arm wins when both are ready, since the select
///   is biased), the connection leaves its room and channel buffers. If the user
///   has no other online connection, any call is declined via `decline_call`.
///   Finally the user's contacts are refreshed.
///
/// # Errors
///
/// Returns an error if the connection cannot be removed from the pool or if
/// refreshing contacts fails. Other cleanup failures are only passed to
/// `trace_err`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    // NOTE(review): if this fails, the `?` returns before the database is told
    // about the lost connection and before any room/channel cleanup.
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a call when none of the user's connections is online
            // (e.g. they did not reconnect within the timeout).
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1361
1362/// Acknowledges a ping from a client, used to keep the connection alive.
1363async fn ping(
1364    _: proto::Ping,
1365    response: Response<proto::Ping>,
1366    _session: MessageContext,
1367) -> Result<()> {
1368    response.send(proto::Ack {})?;
1369    Ok(())
1370}
1371
1372/// Creates a new room for calling (outside of channels)
1373async fn create_room(
1374    _request: proto::CreateRoom,
1375    response: Response<proto::CreateRoom>,
1376    session: MessageContext,
1377) -> Result<()> {
1378    let livekit_room = nanoid::nanoid!(30);
1379
1380    let live_kit_connection_info = util::maybe!(async {
1381        let live_kit = session.app_state.livekit_client.as_ref();
1382        let live_kit = live_kit?;
1383        let user_id = session.user_id().to_string();
1384
1385        let token = live_kit
1386            .room_token(&livekit_room, &user_id.to_string())
1387            .trace_err()?;
1388
1389        Some(proto::LiveKitConnectionInfo {
1390            server_url: live_kit.url().into(),
1391            token,
1392            can_publish: true,
1393        })
1394    })
1395    .await;
1396
1397    let room = session
1398        .db()
1399        .await
1400        .create_room(session.user_id(), session.connection_id, &livekit_room)
1401        .await?;
1402
1403    response.send(proto::CreateRoomResponse {
1404        room: Some(room.clone()),
1405        live_kit_connection_info,
1406    })?;
1407
1408    update_user_contacts(session.user_id(), &session).await?;
1409    Ok(())
1410}
1411
1412/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1413async fn join_room(
1414    request: proto::JoinRoom,
1415    response: Response<proto::JoinRoom>,
1416    session: MessageContext,
1417) -> Result<()> {
1418    let room_id = RoomId::from_proto(request.id);
1419
1420    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1421
1422    if let Some(channel_id) = channel_id {
1423        return join_channel_internal(channel_id, Box::new(response), session).await;
1424    }
1425
1426    let joined_room = {
1427        let room = session
1428            .db()
1429            .await
1430            .join_room(room_id, session.user_id(), session.connection_id)
1431            .await?;
1432        room_updated(&room.room, &session.peer);
1433        room.into_inner()
1434    };
1435
1436    for connection_id in session
1437        .connection_pool()
1438        .await
1439        .user_connection_ids(session.user_id())
1440    {
1441        session
1442            .peer
1443            .send(
1444                connection_id,
1445                proto::CallCanceled {
1446                    room_id: room_id.to_proto(),
1447                },
1448            )
1449            .trace_err();
1450    }
1451
1452    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1453    {
1454        live_kit
1455            .room_token(
1456                &joined_room.room.livekit_room,
1457                &session.user_id().to_string(),
1458            )
1459            .trace_err()
1460            .map(|token| proto::LiveKitConnectionInfo {
1461                server_url: live_kit.url().into(),
1462                token,
1463                can_publish: true,
1464            })
1465    } else {
1466        None
1467    };
1468
1469    response.send(proto::JoinRoomResponse {
1470        room: Some(joined_room.room),
1471        channel_id: None,
1472        live_kit_connection_info,
1473    })?;
1474
1475    update_user_contacts(session.user_id(), &session).await?;
1476    Ok(())
1477}
1478
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database call re-associates the caller's new connection with the room
/// and with the projects it previously re-shared or rejoined. The caller is sent
/// a `RejoinRoomResponse` first. Then the updated room is broadcast, collaborators
/// of re-shared projects learn the caller's new peer id and receive the project's
/// worktree list, and `notify_rejoined_projects` resynchronizes rejoined projects.
/// Finally, the room's channel (if any) and the caller's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    // Filled in from the database result inside the scope below, so they can be
    // used after the value returned by `rejoin_room` has been released.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Reply to the rejoining client before notifying anyone else.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Re-shared projects (presumably those the caller hosts — confirm in
        // `Database::rejoin_room`): swap the caller's old peer id for the new one
        // on every collaborator, then forward the current worktree list.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Rejoined projects: notify collaborators and replay project state to
        // the caller. Takes `&mut` because it moves worktrees/repositories out.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1570
/// Resynchronizes the projects a reconnecting user had rejoined.
///
/// First, every collaborator of each project is sent `UpdateProjectCollaborator`
/// carrying the old peer id (`old_connection_id`) and the new one
/// (`session.connection_id`). These sends are best-effort: errors are only logged.
///
/// Then, for each project, the reconnecting connection is streamed:
/// - each worktree's entries, split into chunks by `split_worktree_update`;
/// - the worktree's diagnostic summaries, if it has any;
/// - the worktree's settings files;
/// - repository updates (split by `split_repository_update`) and removals.
///
/// The worktree and repository collections are moved out of each project with
/// `mem::take` and left empty, which is why the list is borrowed mutably.
///
/// # Errors
/// Returns the first error from sending to the reconnecting connection.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    // Pass 1: tell existing collaborators about the user's new peer id.
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    // Pass 2: replay each project's state to the reconnecting connection.
    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                // Final only once the most recent scan has completed.
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics: the first summary goes in
            // `summary`, the rest in `more_summaries`; nothing is sent if empty.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            // One message per settings file in the worktree.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        // Project-level repository state that changed while the user was away.
        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1655
1656/// leave room disconnects from the room.
1657async fn leave_room(
1658    _: proto::LeaveRoom,
1659    response: Response<proto::LeaveRoom>,
1660    session: MessageContext,
1661) -> Result<()> {
1662    leave_room_for_session(&session, session.connection_id).await?;
1663    response.send(proto::Ack {})?;
1664    Ok(())
1665}
1666
1667/// Updates the permissions of someone else in the room.
1668async fn set_room_participant_role(
1669    request: proto::SetRoomParticipantRole,
1670    response: Response<proto::SetRoomParticipantRole>,
1671    session: MessageContext,
1672) -> Result<()> {
1673    let user_id = UserId::from_proto(request.user_id);
1674    let role = ChannelRole::from(request.role());
1675
1676    let (livekit_room, can_publish) = {
1677        let room = session
1678            .db()
1679            .await
1680            .set_room_participant_role(
1681                session.user_id(),
1682                RoomId::from_proto(request.room_id),
1683                user_id,
1684                role,
1685            )
1686            .await?;
1687
1688        let livekit_room = room.livekit_room.clone();
1689        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1690        room_updated(&room, &session.peer);
1691        (livekit_room, can_publish)
1692    };
1693
1694    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1695        live_kit
1696            .update_participant(
1697                livekit_room.clone(),
1698                request.user_id.to_string(),
1699                livekit_api::proto::ParticipantPermission {
1700                    can_subscribe: true,
1701                    can_publish,
1702                    can_publish_data: can_publish,
1703                    hidden: false,
1704                    recorder: false,
1705                },
1706            )
1707            .await
1708            .trace_err();
1709    }
1710
1711    response.send(proto::Ack {})?;
1712    Ok(())
1713}
1714
1715/// Call someone else into the current room
1716async fn call(
1717    request: proto::Call,
1718    response: Response<proto::Call>,
1719    session: MessageContext,
1720) -> Result<()> {
1721    let room_id = RoomId::from_proto(request.room_id);
1722    let calling_user_id = session.user_id();
1723    let calling_connection_id = session.connection_id;
1724    let called_user_id = UserId::from_proto(request.called_user_id);
1725    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1726    if !session
1727        .db()
1728        .await
1729        .has_contact(calling_user_id, called_user_id)
1730        .await?
1731    {
1732        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1733    }
1734
1735    let incoming_call = {
1736        let (room, incoming_call) = &mut *session
1737            .db()
1738            .await
1739            .call(
1740                room_id,
1741                calling_user_id,
1742                calling_connection_id,
1743                called_user_id,
1744                initial_project_id,
1745            )
1746            .await?;
1747        room_updated(room, &session.peer);
1748        mem::take(incoming_call)
1749    };
1750    update_user_contacts(called_user_id, &session).await?;
1751
1752    let mut calls = session
1753        .connection_pool()
1754        .await
1755        .user_connection_ids(called_user_id)
1756        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1757        .collect::<FuturesUnordered<_>>();
1758
1759    while let Some(call_response) = calls.next().await {
1760        match call_response.as_ref() {
1761            Ok(_) => {
1762                response.send(proto::Ack {})?;
1763                return Ok(());
1764            }
1765            Err(_) => {
1766                call_response.trace_err();
1767            }
1768        }
1769    }
1770
1771    {
1772        let room = session
1773            .db()
1774            .await
1775            .call_failed(room_id, called_user_id)
1776            .await?;
1777        room_updated(&room, &session.peer);
1778    }
1779    update_user_contacts(called_user_id, &session).await?;
1780
1781    Err(anyhow!("failed to ring user"))?
1782}
1783
1784/// Cancel an outgoing call.
1785async fn cancel_call(
1786    request: proto::CancelCall,
1787    response: Response<proto::CancelCall>,
1788    session: MessageContext,
1789) -> Result<()> {
1790    let called_user_id = UserId::from_proto(request.called_user_id);
1791    let room_id = RoomId::from_proto(request.room_id);
1792    {
1793        let room = session
1794            .db()
1795            .await
1796            .cancel_call(room_id, session.connection_id, called_user_id)
1797            .await?;
1798        room_updated(&room, &session.peer);
1799    }
1800
1801    for connection_id in session
1802        .connection_pool()
1803        .await
1804        .user_connection_ids(called_user_id)
1805    {
1806        session
1807            .peer
1808            .send(
1809                connection_id,
1810                proto::CallCanceled {
1811                    room_id: room_id.to_proto(),
1812                },
1813            )
1814            .trace_err();
1815    }
1816    response.send(proto::Ack {})?;
1817
1818    update_user_contacts(called_user_id, &session).await?;
1819    Ok(())
1820}
1821
1822/// Decline an incoming call.
1823async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1824    let room_id = RoomId::from_proto(message.room_id);
1825    {
1826        let room = session
1827            .db()
1828            .await
1829            .decline_call(Some(room_id), session.user_id())
1830            .await?
1831            .context("declining call")?;
1832        room_updated(&room, &session.peer);
1833    }
1834
1835    for connection_id in session
1836        .connection_pool()
1837        .await
1838        .user_connection_ids(session.user_id())
1839    {
1840        session
1841            .peer
1842            .send(
1843                connection_id,
1844                proto::CallCanceled {
1845                    room_id: room_id.to_proto(),
1846                },
1847            )
1848            .trace_err();
1849    }
1850    update_user_contacts(session.user_id(), &session).await?;
1851    Ok(())
1852}
1853
1854/// Updates other participants in the room with your current location.
1855async fn update_participant_location(
1856    request: proto::UpdateParticipantLocation,
1857    response: Response<proto::UpdateParticipantLocation>,
1858    session: MessageContext,
1859) -> Result<()> {
1860    let room_id = RoomId::from_proto(request.room_id);
1861    let location = request.location.context("invalid location")?;
1862
1863    let db = session.db().await;
1864    let room = db
1865        .update_room_participant_location(room_id, session.connection_id, location)
1866        .await?;
1867
1868    room_updated(&room, &session.peer);
1869    response.send(proto::Ack {})?;
1870    Ok(())
1871}
1872
1873/// Share a project into the room.
1874async fn share_project(
1875    request: proto::ShareProject,
1876    response: Response<proto::ShareProject>,
1877    session: MessageContext,
1878) -> Result<()> {
1879    let (project_id, room) = &*session
1880        .db()
1881        .await
1882        .share_project(
1883            RoomId::from_proto(request.room_id),
1884            session.connection_id,
1885            &request.worktrees,
1886            request.is_ssh_project,
1887        )
1888        .await?;
1889    response.send(proto::ShareProjectResponse {
1890        project_id: project_id.to_proto(),
1891    })?;
1892    room_updated(room, &session.peer);
1893
1894    Ok(())
1895}
1896
1897/// Unshare a project from the room.
1898async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1899    let project_id = ProjectId::from_proto(message.project_id);
1900    unshare_project_internal(project_id, session.connection_id, &session).await
1901}
1902
1903async fn unshare_project_internal(
1904    project_id: ProjectId,
1905    connection_id: ConnectionId,
1906    session: &Session,
1907) -> Result<()> {
1908    let delete = {
1909        let room_guard = session
1910            .db()
1911            .await
1912            .unshare_project(project_id, connection_id)
1913            .await?;
1914
1915        let (delete, room, guest_connection_ids) = &*room_guard;
1916
1917        let message = proto::UnshareProject {
1918            project_id: project_id.to_proto(),
1919        };
1920
1921        broadcast(
1922            Some(connection_id),
1923            guest_connection_ids.iter().copied(),
1924            |conn_id| session.peer.send(conn_id, message.clone()),
1925        );
1926        if let Some(room) = room {
1927            room_updated(room, &session.peer);
1928        }
1929
1930        *delete
1931    };
1932
1933    if delete {
1934        let db = session.db().await;
1935        db.delete_project(project_id).await?;
1936    }
1937
1938    Ok(())
1939}
1940
1941/// Join someone elses shared project.
1942async fn join_project(
1943    request: proto::JoinProject,
1944    response: Response<proto::JoinProject>,
1945    session: MessageContext,
1946) -> Result<()> {
1947    let project_id = ProjectId::from_proto(request.project_id);
1948
1949    tracing::info!(%project_id, "join project");
1950
1951    let db = session.db().await;
1952    let (project, replica_id) = &mut *db
1953        .join_project(
1954            project_id,
1955            session.connection_id,
1956            session.user_id(),
1957            request.committer_name.clone(),
1958            request.committer_email.clone(),
1959        )
1960        .await?;
1961    drop(db);
1962    tracing::info!(%project_id, "join remote project");
1963    let collaborators = project
1964        .collaborators
1965        .iter()
1966        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1967        .map(|collaborator| collaborator.to_proto())
1968        .collect::<Vec<_>>();
1969    let project_id = project.id;
1970    let guest_user_id = session.user_id();
1971
1972    let worktrees = project
1973        .worktrees
1974        .iter()
1975        .map(|(id, worktree)| proto::WorktreeMetadata {
1976            id: *id,
1977            root_name: worktree.root_name.clone(),
1978            visible: worktree.visible,
1979            abs_path: worktree.abs_path.clone(),
1980        })
1981        .collect::<Vec<_>>();
1982
1983    let add_project_collaborator = proto::AddProjectCollaborator {
1984        project_id: project_id.to_proto(),
1985        collaborator: Some(proto::Collaborator {
1986            peer_id: Some(session.connection_id.into()),
1987            replica_id: replica_id.0 as u32,
1988            user_id: guest_user_id.to_proto(),
1989            is_host: false,
1990            committer_name: request.committer_name.clone(),
1991            committer_email: request.committer_email.clone(),
1992        }),
1993    };
1994
1995    for collaborator in &collaborators {
1996        session
1997            .peer
1998            .send(
1999                collaborator.peer_id.unwrap().into(),
2000                add_project_collaborator.clone(),
2001            )
2002            .trace_err();
2003    }
2004
2005    // First, we send the metadata associated with each worktree.
2006    let (language_servers, language_server_capabilities) = project
2007        .language_servers
2008        .clone()
2009        .into_iter()
2010        .map(|server| (server.server, server.capabilities))
2011        .unzip();
2012    response.send(proto::JoinProjectResponse {
2013        project_id: project.id.0 as u64,
2014        worktrees: worktrees.clone(),
2015        replica_id: replica_id.0 as u32,
2016        collaborators: collaborators.clone(),
2017        language_servers,
2018        language_server_capabilities,
2019        role: project.role.into(),
2020    })?;
2021
2022    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2023        // Stream this worktree's entries.
2024        let message = proto::UpdateWorktree {
2025            project_id: project_id.to_proto(),
2026            worktree_id,
2027            abs_path: worktree.abs_path.clone(),
2028            root_name: worktree.root_name,
2029            updated_entries: worktree.entries,
2030            removed_entries: Default::default(),
2031            scan_id: worktree.scan_id,
2032            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2033            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2034            removed_repositories: Default::default(),
2035        };
2036        for update in proto::split_worktree_update(message) {
2037            session.peer.send(session.connection_id, update.clone())?;
2038        }
2039
2040        // Stream this worktree's diagnostics.
2041        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
2042        if let Some(summary) = worktree_diagnostics.next() {
2043            let message = proto::UpdateDiagnosticSummary {
2044                project_id: project.id.to_proto(),
2045                worktree_id: worktree.id,
2046                summary: Some(summary),
2047                more_summaries: worktree_diagnostics.collect(),
2048            };
2049            session.peer.send(session.connection_id, message)?;
2050        }
2051
2052        for settings_file in worktree.settings_files {
2053            session.peer.send(
2054                session.connection_id,
2055                proto::UpdateWorktreeSettings {
2056                    project_id: project_id.to_proto(),
2057                    worktree_id: worktree.id,
2058                    path: settings_file.path,
2059                    content: Some(settings_file.content),
2060                    kind: Some(settings_file.kind.to_proto() as i32),
2061                },
2062            )?;
2063        }
2064    }
2065
2066    for repository in mem::take(&mut project.repositories) {
2067        for update in split_repository_update(repository) {
2068            session.peer.send(session.connection_id, update)?;
2069        }
2070    }
2071
2072    for language_server in &project.language_servers {
2073        session.peer.send(
2074            session.connection_id,
2075            proto::UpdateLanguageServer {
2076                project_id: project_id.to_proto(),
2077                server_name: Some(language_server.server.name.clone()),
2078                language_server_id: language_server.server.id,
2079                variant: Some(
2080                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2081                        proto::LspDiskBasedDiagnosticsUpdated {},
2082                    ),
2083                ),
2084            },
2085        )?;
2086    }
2087
2088    Ok(())
2089}
2090
2091/// Leave someone elses shared project.
2092async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2093    let sender_id = session.connection_id;
2094    let project_id = ProjectId::from_proto(request.project_id);
2095    let db = session.db().await;
2096
2097    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2098    tracing::info!(
2099        %project_id,
2100        "leave project"
2101    );
2102
2103    project_left(project, &session);
2104    if let Some(room) = room {
2105        room_updated(room, &session.peer);
2106    }
2107
2108    Ok(())
2109}
2110
2111/// Updates other participants with changes to the project
2112async fn update_project(
2113    request: proto::UpdateProject,
2114    response: Response<proto::UpdateProject>,
2115    session: MessageContext,
2116) -> Result<()> {
2117    let project_id = ProjectId::from_proto(request.project_id);
2118    let (room, guest_connection_ids) = &*session
2119        .db()
2120        .await
2121        .update_project(project_id, session.connection_id, &request.worktrees)
2122        .await?;
2123    broadcast(
2124        Some(session.connection_id),
2125        guest_connection_ids.iter().copied(),
2126        |connection_id| {
2127            session
2128                .peer
2129                .forward_send(session.connection_id, connection_id, request.clone())
2130        },
2131    );
2132    if let Some(room) = room {
2133        room_updated(room, &session.peer);
2134    }
2135    response.send(proto::Ack {})?;
2136
2137    Ok(())
2138}
2139
2140/// Updates other participants with changes to the worktree
2141async fn update_worktree(
2142    request: proto::UpdateWorktree,
2143    response: Response<proto::UpdateWorktree>,
2144    session: MessageContext,
2145) -> Result<()> {
2146    let guest_connection_ids = session
2147        .db()
2148        .await
2149        .update_worktree(&request, session.connection_id)
2150        .await?;
2151
2152    broadcast(
2153        Some(session.connection_id),
2154        guest_connection_ids.iter().copied(),
2155        |connection_id| {
2156            session
2157                .peer
2158                .forward_send(session.connection_id, connection_id, request.clone())
2159        },
2160    );
2161    response.send(proto::Ack {})?;
2162    Ok(())
2163}
2164
2165async fn update_repository(
2166    request: proto::UpdateRepository,
2167    response: Response<proto::UpdateRepository>,
2168    session: MessageContext,
2169) -> Result<()> {
2170    let guest_connection_ids = session
2171        .db()
2172        .await
2173        .update_repository(&request, session.connection_id)
2174        .await?;
2175
2176    broadcast(
2177        Some(session.connection_id),
2178        guest_connection_ids.iter().copied(),
2179        |connection_id| {
2180            session
2181                .peer
2182                .forward_send(session.connection_id, connection_id, request.clone())
2183        },
2184    );
2185    response.send(proto::Ack {})?;
2186    Ok(())
2187}
2188
2189async fn remove_repository(
2190    request: proto::RemoveRepository,
2191    response: Response<proto::RemoveRepository>,
2192    session: MessageContext,
2193) -> Result<()> {
2194    let guest_connection_ids = session
2195        .db()
2196        .await
2197        .remove_repository(&request, session.connection_id)
2198        .await?;
2199
2200    broadcast(
2201        Some(session.connection_id),
2202        guest_connection_ids.iter().copied(),
2203        |connection_id| {
2204            session
2205                .peer
2206                .forward_send(session.connection_id, connection_id, request.clone())
2207        },
2208    );
2209    response.send(proto::Ack {})?;
2210    Ok(())
2211}
2212
2213/// Updates other participants with changes to the diagnostics
2214async fn update_diagnostic_summary(
2215    message: proto::UpdateDiagnosticSummary,
2216    session: MessageContext,
2217) -> Result<()> {
2218    let guest_connection_ids = session
2219        .db()
2220        .await
2221        .update_diagnostic_summary(&message, session.connection_id)
2222        .await?;
2223
2224    broadcast(
2225        Some(session.connection_id),
2226        guest_connection_ids.iter().copied(),
2227        |connection_id| {
2228            session
2229                .peer
2230                .forward_send(session.connection_id, connection_id, message.clone())
2231        },
2232    );
2233
2234    Ok(())
2235}
2236
2237/// Updates other participants with changes to the worktree settings
2238async fn update_worktree_settings(
2239    message: proto::UpdateWorktreeSettings,
2240    session: MessageContext,
2241) -> Result<()> {
2242    let guest_connection_ids = session
2243        .db()
2244        .await
2245        .update_worktree_settings(&message, session.connection_id)
2246        .await?;
2247
2248    broadcast(
2249        Some(session.connection_id),
2250        guest_connection_ids.iter().copied(),
2251        |connection_id| {
2252            session
2253                .peer
2254                .forward_send(session.connection_id, connection_id, message.clone())
2255        },
2256    );
2257
2258    Ok(())
2259}
2260
2261/// Notify other participants that a language server has started.
2262async fn start_language_server(
2263    request: proto::StartLanguageServer,
2264    session: MessageContext,
2265) -> Result<()> {
2266    let guest_connection_ids = session
2267        .db()
2268        .await
2269        .start_language_server(&request, session.connection_id)
2270        .await?;
2271
2272    broadcast(
2273        Some(session.connection_id),
2274        guest_connection_ids.iter().copied(),
2275        |connection_id| {
2276            session
2277                .peer
2278                .forward_send(session.connection_id, connection_id, request.clone())
2279        },
2280    );
2281    Ok(())
2282}
2283
2284/// Notify other participants that a language server has changed.
2285async fn update_language_server(
2286    request: proto::UpdateLanguageServer,
2287    session: MessageContext,
2288) -> Result<()> {
2289    let project_id = ProjectId::from_proto(request.project_id);
2290    let db = session.db().await;
2291
2292    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2293        && let Some(capabilities) = update.capabilities.clone() {
2294            db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2295                .await?;
2296        }
2297
2298    let project_connection_ids = db
2299        .project_connection_ids(project_id, session.connection_id, true)
2300        .await?;
2301    broadcast(
2302        Some(session.connection_id),
2303        project_connection_ids.iter().copied(),
2304        |connection_id| {
2305            session
2306                .peer
2307                .forward_send(session.connection_id, connection_id, request.clone())
2308        },
2309    );
2310    Ok(())
2311}
2312
2313/// forward a project request to the host. These requests should be read only
2314/// as guests are allowed to send them.
2315async fn forward_read_only_project_request<T>(
2316    request: T,
2317    response: Response<T>,
2318    session: MessageContext,
2319) -> Result<()>
2320where
2321    T: EntityMessage + RequestMessage,
2322{
2323    let project_id = ProjectId::from_proto(request.remote_entity_id());
2324    let host_connection_id = session
2325        .db()
2326        .await
2327        .host_for_read_only_project_request(project_id, session.connection_id)
2328        .await?;
2329    let payload = session.forward_request(host_connection_id, request).await?;
2330    response.send(payload)?;
2331    Ok(())
2332}
2333
2334/// forward a project request to the host. These requests are disallowed
2335/// for guests.
2336async fn forward_mutating_project_request<T>(
2337    request: T,
2338    response: Response<T>,
2339    session: MessageContext,
2340) -> Result<()>
2341where
2342    T: EntityMessage + RequestMessage,
2343{
2344    let project_id = ProjectId::from_proto(request.remote_entity_id());
2345
2346    let host_connection_id = session
2347        .db()
2348        .await
2349        .host_for_mutating_project_request(project_id, session.connection_id)
2350        .await?;
2351    let payload = session.forward_request(host_connection_id, request).await?;
2352    response.send(payload)?;
2353    Ok(())
2354}
2355
2356async fn multi_lsp_query(
2357    request: MultiLspQuery,
2358    response: Response<MultiLspQuery>,
2359    session: MessageContext,
2360) -> Result<()> {
2361    tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2362    tracing::info!("multi_lsp_query message received");
2363    forward_mutating_project_request(request, response, session).await
2364}
2365
2366/// Notify other participants that a new buffer has been created
2367async fn create_buffer_for_peer(
2368    request: proto::CreateBufferForPeer,
2369    session: MessageContext,
2370) -> Result<()> {
2371    session
2372        .db()
2373        .await
2374        .check_user_is_project_host(
2375            ProjectId::from_proto(request.project_id),
2376            session.connection_id,
2377        )
2378        .await?;
2379    let peer_id = request.peer_id.context("invalid peer id")?;
2380    session
2381        .peer
2382        .forward_send(session.connection_id, peer_id.into(), request)?;
2383    Ok(())
2384}
2385
2386/// Notify other participants that a buffer has been updated. This is
2387/// allowed for guests as long as the update is limited to selections.
2388async fn update_buffer(
2389    request: proto::UpdateBuffer,
2390    response: Response<proto::UpdateBuffer>,
2391    session: MessageContext,
2392) -> Result<()> {
2393    let project_id = ProjectId::from_proto(request.project_id);
2394    let mut capability = Capability::ReadOnly;
2395
2396    for op in request.operations.iter() {
2397        match op.variant {
2398            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2399            Some(_) => capability = Capability::ReadWrite,
2400        }
2401    }
2402
2403    let host = {
2404        let guard = session
2405            .db()
2406            .await
2407            .connections_for_buffer_update(project_id, session.connection_id, capability)
2408            .await?;
2409
2410        let (host, guests) = &*guard;
2411
2412        broadcast(
2413            Some(session.connection_id),
2414            guests.clone(),
2415            |connection_id| {
2416                session
2417                    .peer
2418                    .forward_send(session.connection_id, connection_id, request.clone())
2419            },
2420        );
2421
2422        *host
2423    };
2424
2425    if host != session.connection_id {
2426        session.forward_request(host, request.clone()).await?;
2427    }
2428
2429    response.send(proto::Ack {})?;
2430    Ok(())
2431}
2432
2433async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2434    let project_id = ProjectId::from_proto(message.project_id);
2435
2436    let operation = message.operation.as_ref().context("invalid operation")?;
2437    let capability = match operation.variant.as_ref() {
2438        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2439            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2440                match buffer_op.variant {
2441                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2442                        Capability::ReadOnly
2443                    }
2444                    _ => Capability::ReadWrite,
2445                }
2446            } else {
2447                Capability::ReadWrite
2448            }
2449        }
2450        Some(_) => Capability::ReadWrite,
2451        None => Capability::ReadOnly,
2452    };
2453
2454    let guard = session
2455        .db()
2456        .await
2457        .connections_for_buffer_update(project_id, session.connection_id, capability)
2458        .await?;
2459
2460    let (host, guests) = &*guard;
2461
2462    broadcast(
2463        Some(session.connection_id),
2464        guests.iter().chain([host]).copied(),
2465        |connection_id| {
2466            session
2467                .peer
2468                .forward_send(session.connection_id, connection_id, message.clone())
2469        },
2470    );
2471
2472    Ok(())
2473}
2474
2475/// Notify other participants that a project has been updated.
2476async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2477    request: T,
2478    session: MessageContext,
2479) -> Result<()> {
2480    let project_id = ProjectId::from_proto(request.remote_entity_id());
2481    let project_connection_ids = session
2482        .db()
2483        .await
2484        .project_connection_ids(project_id, session.connection_id, false)
2485        .await?;
2486
2487    broadcast(
2488        Some(session.connection_id),
2489        project_connection_ids.iter().copied(),
2490        |connection_id| {
2491            session
2492                .peer
2493                .forward_send(session.connection_id, connection_id, request.clone())
2494        },
2495    );
2496    Ok(())
2497}
2498
2499/// Start following another user in a call.
2500async fn follow(
2501    request: proto::Follow,
2502    response: Response<proto::Follow>,
2503    session: MessageContext,
2504) -> Result<()> {
2505    let room_id = RoomId::from_proto(request.room_id);
2506    let project_id = request.project_id.map(ProjectId::from_proto);
2507    let leader_id = request.leader_id.context("invalid leader id")?.into();
2508    let follower_id = session.connection_id;
2509
2510    session
2511        .db()
2512        .await
2513        .check_room_participants(room_id, leader_id, session.connection_id)
2514        .await?;
2515
2516    let response_payload = session.forward_request(leader_id, request).await?;
2517    response.send(response_payload)?;
2518
2519    if let Some(project_id) = project_id {
2520        let room = session
2521            .db()
2522            .await
2523            .follow(room_id, project_id, leader_id, follower_id)
2524            .await?;
2525        room_updated(&room, &session.peer);
2526    }
2527
2528    Ok(())
2529}
2530
2531/// Stop following another user in a call.
2532async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2533    let room_id = RoomId::from_proto(request.room_id);
2534    let project_id = request.project_id.map(ProjectId::from_proto);
2535    let leader_id = request.leader_id.context("invalid leader id")?.into();
2536    let follower_id = session.connection_id;
2537
2538    session
2539        .db()
2540        .await
2541        .check_room_participants(room_id, leader_id, session.connection_id)
2542        .await?;
2543
2544    session
2545        .peer
2546        .forward_send(session.connection_id, leader_id, request)?;
2547
2548    if let Some(project_id) = project_id {
2549        let room = session
2550            .db()
2551            .await
2552            .unfollow(room_id, project_id, leader_id, follower_id)
2553            .await?;
2554        room_updated(&room, &session.peer);
2555    }
2556
2557    Ok(())
2558}
2559
2560/// Notify everyone following you of your current location.
2561async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2562    let room_id = RoomId::from_proto(request.room_id);
2563    let database = session.db.lock().await;
2564
2565    let connection_ids = if let Some(project_id) = request.project_id {
2566        let project_id = ProjectId::from_proto(project_id);
2567        database
2568            .project_connection_ids(project_id, session.connection_id, true)
2569            .await?
2570    } else {
2571        database
2572            .room_connection_ids(room_id, session.connection_id)
2573            .await?
2574    };
2575
2576    // For now, don't send view update messages back to that view's current leader.
2577    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2578        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2579        _ => None,
2580    });
2581
2582    for connection_id in connection_ids.iter().cloned() {
2583        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2584            session
2585                .peer
2586                .forward_send(session.connection_id, connection_id, request.clone())?;
2587        }
2588    }
2589    Ok(())
2590}
2591
2592/// Get public data about users.
2593async fn get_users(
2594    request: proto::GetUsers,
2595    response: Response<proto::GetUsers>,
2596    session: MessageContext,
2597) -> Result<()> {
2598    let user_ids = request
2599        .user_ids
2600        .into_iter()
2601        .map(UserId::from_proto)
2602        .collect();
2603    let users = session
2604        .db()
2605        .await
2606        .get_users_by_ids(user_ids)
2607        .await?
2608        .into_iter()
2609        .map(|user| proto::User {
2610            id: user.id.to_proto(),
2611            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2612            github_login: user.github_login,
2613            name: user.name,
2614        })
2615        .collect();
2616    response.send(proto::UsersResponse { users })?;
2617    Ok(())
2618}
2619
2620/// Search for users (to invite) buy Github login
2621async fn fuzzy_search_users(
2622    request: proto::FuzzySearchUsers,
2623    response: Response<proto::FuzzySearchUsers>,
2624    session: MessageContext,
2625) -> Result<()> {
2626    let query = request.query;
2627    let users = match query.len() {
2628        0 => vec![],
2629        1 | 2 => session
2630            .db()
2631            .await
2632            .get_user_by_github_login(&query)
2633            .await?
2634            .into_iter()
2635            .collect(),
2636        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2637    };
2638    let users = users
2639        .into_iter()
2640        .filter(|user| user.id != session.user_id())
2641        .map(|user| proto::User {
2642            id: user.id.to_proto(),
2643            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2644            github_login: user.github_login,
2645            name: user.name,
2646        })
2647        .collect();
2648    response.send(proto::UsersResponse { users })?;
2649    Ok(())
2650}
2651
2652/// Send a contact request to another user.
2653async fn request_contact(
2654    request: proto::RequestContact,
2655    response: Response<proto::RequestContact>,
2656    session: MessageContext,
2657) -> Result<()> {
2658    let requester_id = session.user_id();
2659    let responder_id = UserId::from_proto(request.responder_id);
2660    if requester_id == responder_id {
2661        return Err(anyhow!("cannot add yourself as a contact"))?;
2662    }
2663
2664    let notifications = session
2665        .db()
2666        .await
2667        .send_contact_request(requester_id, responder_id)
2668        .await?;
2669
2670    // Update outgoing contact requests of requester
2671    let mut update = proto::UpdateContacts::default();
2672    update.outgoing_requests.push(responder_id.to_proto());
2673    for connection_id in session
2674        .connection_pool()
2675        .await
2676        .user_connection_ids(requester_id)
2677    {
2678        session.peer.send(connection_id, update.clone())?;
2679    }
2680
2681    // Update incoming contact requests of responder
2682    let mut update = proto::UpdateContacts::default();
2683    update
2684        .incoming_requests
2685        .push(proto::IncomingContactRequest {
2686            requester_id: requester_id.to_proto(),
2687        });
2688    let connection_pool = session.connection_pool().await;
2689    for connection_id in connection_pool.user_connection_ids(responder_id) {
2690        session.peer.send(connection_id, update.clone())?;
2691    }
2692
2693    send_notifications(&connection_pool, &session.peer, notifications);
2694
2695    response.send(proto::Ack {})?;
2696    Ok(())
2697}
2698
2699/// Accept or decline a contact request
2700async fn respond_to_contact_request(
2701    request: proto::RespondToContactRequest,
2702    response: Response<proto::RespondToContactRequest>,
2703    session: MessageContext,
2704) -> Result<()> {
2705    let responder_id = session.user_id();
2706    let requester_id = UserId::from_proto(request.requester_id);
2707    let db = session.db().await;
2708    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2709        db.dismiss_contact_notification(responder_id, requester_id)
2710            .await?;
2711    } else {
2712        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2713
2714        let notifications = db
2715            .respond_to_contact_request(responder_id, requester_id, accept)
2716            .await?;
2717        let requester_busy = db.is_user_busy(requester_id).await?;
2718        let responder_busy = db.is_user_busy(responder_id).await?;
2719
2720        let pool = session.connection_pool().await;
2721        // Update responder with new contact
2722        let mut update = proto::UpdateContacts::default();
2723        if accept {
2724            update
2725                .contacts
2726                .push(contact_for_user(requester_id, requester_busy, &pool));
2727        }
2728        update
2729            .remove_incoming_requests
2730            .push(requester_id.to_proto());
2731        for connection_id in pool.user_connection_ids(responder_id) {
2732            session.peer.send(connection_id, update.clone())?;
2733        }
2734
2735        // Update requester with new contact
2736        let mut update = proto::UpdateContacts::default();
2737        if accept {
2738            update
2739                .contacts
2740                .push(contact_for_user(responder_id, responder_busy, &pool));
2741        }
2742        update
2743            .remove_outgoing_requests
2744            .push(responder_id.to_proto());
2745
2746        for connection_id in pool.user_connection_ids(requester_id) {
2747            session.peer.send(connection_id, update.clone())?;
2748        }
2749
2750        send_notifications(&pool, &session.peer, notifications);
2751    }
2752
2753    response.send(proto::Ack {})?;
2754    Ok(())
2755}
2756
2757/// Remove a contact.
2758async fn remove_contact(
2759    request: proto::RemoveContact,
2760    response: Response<proto::RemoveContact>,
2761    session: MessageContext,
2762) -> Result<()> {
2763    let requester_id = session.user_id();
2764    let responder_id = UserId::from_proto(request.user_id);
2765    let db = session.db().await;
2766    let (contact_accepted, deleted_notification_id) =
2767        db.remove_contact(requester_id, responder_id).await?;
2768
2769    let pool = session.connection_pool().await;
2770    // Update outgoing contact requests of requester
2771    let mut update = proto::UpdateContacts::default();
2772    if contact_accepted {
2773        update.remove_contacts.push(responder_id.to_proto());
2774    } else {
2775        update
2776            .remove_outgoing_requests
2777            .push(responder_id.to_proto());
2778    }
2779    for connection_id in pool.user_connection_ids(requester_id) {
2780        session.peer.send(connection_id, update.clone())?;
2781    }
2782
2783    // Update incoming contact requests of responder
2784    let mut update = proto::UpdateContacts::default();
2785    if contact_accepted {
2786        update.remove_contacts.push(requester_id.to_proto());
2787    } else {
2788        update
2789            .remove_incoming_requests
2790            .push(requester_id.to_proto());
2791    }
2792    for connection_id in pool.user_connection_ids(responder_id) {
2793        session.peer.send(connection_id, update.clone())?;
2794        if let Some(notification_id) = deleted_notification_id {
2795            session.peer.send(
2796                connection_id,
2797                proto::DeleteNotification {
2798                    notification_id: notification_id.to_proto(),
2799                },
2800            )?;
2801        }
2802    }
2803
2804    response.send(proto::Ack {})?;
2805    Ok(())
2806}
2807
2808fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2809    version.0.minor() < 139
2810}
2811
2812async fn subscribe_to_channels(
2813    _: proto::SubscribeToChannels,
2814    session: MessageContext,
2815) -> Result<()> {
2816    subscribe_user_to_channels(session.user_id(), &session).await?;
2817    Ok(())
2818}
2819
2820async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2821    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2822    let mut pool = session.connection_pool().await;
2823    for membership in &channels_for_user.channel_memberships {
2824        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2825    }
2826    session.peer.send(
2827        session.connection_id,
2828        build_update_user_channels(&channels_for_user),
2829    )?;
2830    session.peer.send(
2831        session.connection_id,
2832        build_channels_update(channels_for_user),
2833    )?;
2834    Ok(())
2835}
2836
2837/// Creates a new channel.
2838async fn create_channel(
2839    request: proto::CreateChannel,
2840    response: Response<proto::CreateChannel>,
2841    session: MessageContext,
2842) -> Result<()> {
2843    let db = session.db().await;
2844
2845    let parent_id = request.parent_id.map(ChannelId::from_proto);
2846    let (channel, membership) = db
2847        .create_channel(&request.name, parent_id, session.user_id())
2848        .await?;
2849
2850    let root_id = channel.root_id();
2851    let channel = Channel::from_model(channel);
2852
2853    response.send(proto::CreateChannelResponse {
2854        channel: Some(channel.to_proto()),
2855        parent_id: request.parent_id,
2856    })?;
2857
2858    let mut connection_pool = session.connection_pool().await;
2859    if let Some(membership) = membership {
2860        connection_pool.subscribe_to_channel(
2861            membership.user_id,
2862            membership.channel_id,
2863            membership.role,
2864        );
2865        let update = proto::UpdateUserChannels {
2866            channel_memberships: vec![proto::ChannelMembership {
2867                channel_id: membership.channel_id.to_proto(),
2868                role: membership.role.into(),
2869            }],
2870            ..Default::default()
2871        };
2872        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2873            session.peer.send(connection_id, update.clone())?;
2874        }
2875    }
2876
2877    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2878        if !role.can_see_channel(channel.visibility) {
2879            continue;
2880        }
2881
2882        let update = proto::UpdateChannels {
2883            channels: vec![channel.to_proto()],
2884            ..Default::default()
2885        };
2886        session.peer.send(connection_id, update.clone())?;
2887    }
2888
2889    Ok(())
2890}
2891
2892/// Delete a channel
2893async fn delete_channel(
2894    request: proto::DeleteChannel,
2895    response: Response<proto::DeleteChannel>,
2896    session: MessageContext,
2897) -> Result<()> {
2898    let db = session.db().await;
2899
2900    let channel_id = request.channel_id;
2901    let (root_channel, removed_channels) = db
2902        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2903        .await?;
2904    response.send(proto::Ack {})?;
2905
2906    // Notify members of removed channels
2907    let mut update = proto::UpdateChannels::default();
2908    update
2909        .delete_channels
2910        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2911
2912    let connection_pool = session.connection_pool().await;
2913    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2914        session.peer.send(connection_id, update.clone())?;
2915    }
2916
2917    Ok(())
2918}
2919
2920/// Invite someone to join a channel.
2921async fn invite_channel_member(
2922    request: proto::InviteChannelMember,
2923    response: Response<proto::InviteChannelMember>,
2924    session: MessageContext,
2925) -> Result<()> {
2926    let db = session.db().await;
2927    let channel_id = ChannelId::from_proto(request.channel_id);
2928    let invitee_id = UserId::from_proto(request.user_id);
2929    let InviteMemberResult {
2930        channel,
2931        notifications,
2932    } = db
2933        .invite_channel_member(
2934            channel_id,
2935            invitee_id,
2936            session.user_id(),
2937            request.role().into(),
2938        )
2939        .await?;
2940
2941    let update = proto::UpdateChannels {
2942        channel_invitations: vec![channel.to_proto()],
2943        ..Default::default()
2944    };
2945
2946    let connection_pool = session.connection_pool().await;
2947    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2948        session.peer.send(connection_id, update.clone())?;
2949    }
2950
2951    send_notifications(&connection_pool, &session.peer, notifications);
2952
2953    response.send(proto::Ack {})?;
2954    Ok(())
2955}
2956
2957/// remove someone from a channel
2958async fn remove_channel_member(
2959    request: proto::RemoveChannelMember,
2960    response: Response<proto::RemoveChannelMember>,
2961    session: MessageContext,
2962) -> Result<()> {
2963    let db = session.db().await;
2964    let channel_id = ChannelId::from_proto(request.channel_id);
2965    let member_id = UserId::from_proto(request.user_id);
2966
2967    let RemoveChannelMemberResult {
2968        membership_update,
2969        notification_id,
2970    } = db
2971        .remove_channel_member(channel_id, member_id, session.user_id())
2972        .await?;
2973
2974    let mut connection_pool = session.connection_pool().await;
2975    notify_membership_updated(
2976        &mut connection_pool,
2977        membership_update,
2978        member_id,
2979        &session.peer,
2980    );
2981    for connection_id in connection_pool.user_connection_ids(member_id) {
2982        if let Some(notification_id) = notification_id {
2983            session
2984                .peer
2985                .send(
2986                    connection_id,
2987                    proto::DeleteNotification {
2988                        notification_id: notification_id.to_proto(),
2989                    },
2990                )
2991                .trace_err();
2992        }
2993    }
2994
2995    response.send(proto::Ack {})?;
2996    Ok(())
2997}
2998
2999/// Toggle the channel between public and private.
3000/// Care is taken to maintain the invariant that public channels only descend from public channels,
3001/// (though members-only channels can appear at any point in the hierarchy).
3002async fn set_channel_visibility(
3003    request: proto::SetChannelVisibility,
3004    response: Response<proto::SetChannelVisibility>,
3005    session: MessageContext,
3006) -> Result<()> {
3007    let db = session.db().await;
3008    let channel_id = ChannelId::from_proto(request.channel_id);
3009    let visibility = request.visibility().into();
3010
3011    let channel_model = db
3012        .set_channel_visibility(channel_id, visibility, session.user_id())
3013        .await?;
3014    let root_id = channel_model.root_id();
3015    let channel = Channel::from_model(channel_model);
3016
3017    let mut connection_pool = session.connection_pool().await;
3018    for (user_id, role) in connection_pool
3019        .channel_user_ids(root_id)
3020        .collect::<Vec<_>>()
3021        .into_iter()
3022    {
3023        let update = if role.can_see_channel(channel.visibility) {
3024            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3025            proto::UpdateChannels {
3026                channels: vec![channel.to_proto()],
3027                ..Default::default()
3028            }
3029        } else {
3030            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3031            proto::UpdateChannels {
3032                delete_channels: vec![channel.id.to_proto()],
3033                ..Default::default()
3034            }
3035        };
3036
3037        for connection_id in connection_pool.user_connection_ids(user_id) {
3038            session.peer.send(connection_id, update.clone())?;
3039        }
3040    }
3041
3042    response.send(proto::Ack {})?;
3043    Ok(())
3044}
3045
3046/// Alter the role for a user in the channel.
3047async fn set_channel_member_role(
3048    request: proto::SetChannelMemberRole,
3049    response: Response<proto::SetChannelMemberRole>,
3050    session: MessageContext,
3051) -> Result<()> {
3052    let db = session.db().await;
3053    let channel_id = ChannelId::from_proto(request.channel_id);
3054    let member_id = UserId::from_proto(request.user_id);
3055    let result = db
3056        .set_channel_member_role(
3057            channel_id,
3058            session.user_id(),
3059            member_id,
3060            request.role().into(),
3061        )
3062        .await?;
3063
3064    match result {
3065        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3066            let mut connection_pool = session.connection_pool().await;
3067            notify_membership_updated(
3068                &mut connection_pool,
3069                membership_update,
3070                member_id,
3071                &session.peer,
3072            )
3073        }
3074        db::SetMemberRoleResult::InviteUpdated(channel) => {
3075            let update = proto::UpdateChannels {
3076                channel_invitations: vec![channel.to_proto()],
3077                ..Default::default()
3078            };
3079
3080            for connection_id in session
3081                .connection_pool()
3082                .await
3083                .user_connection_ids(member_id)
3084            {
3085                session.peer.send(connection_id, update.clone())?;
3086            }
3087        }
3088    }
3089
3090    response.send(proto::Ack {})?;
3091    Ok(())
3092}
3093
3094/// Change the name of a channel
3095async fn rename_channel(
3096    request: proto::RenameChannel,
3097    response: Response<proto::RenameChannel>,
3098    session: MessageContext,
3099) -> Result<()> {
3100    let db = session.db().await;
3101    let channel_id = ChannelId::from_proto(request.channel_id);
3102    let channel_model = db
3103        .rename_channel(channel_id, session.user_id(), &request.name)
3104        .await?;
3105    let root_id = channel_model.root_id();
3106    let channel = Channel::from_model(channel_model);
3107
3108    response.send(proto::RenameChannelResponse {
3109        channel: Some(channel.to_proto()),
3110    })?;
3111
3112    let connection_pool = session.connection_pool().await;
3113    let update = proto::UpdateChannels {
3114        channels: vec![channel.to_proto()],
3115        ..Default::default()
3116    };
3117    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3118        if role.can_see_channel(channel.visibility) {
3119            session.peer.send(connection_id, update.clone())?;
3120        }
3121    }
3122
3123    Ok(())
3124}
3125
3126/// Move a channel to a new parent.
3127async fn move_channel(
3128    request: proto::MoveChannel,
3129    response: Response<proto::MoveChannel>,
3130    session: MessageContext,
3131) -> Result<()> {
3132    let channel_id = ChannelId::from_proto(request.channel_id);
3133    let to = ChannelId::from_proto(request.to);
3134
3135    let (root_id, channels) = session
3136        .db()
3137        .await
3138        .move_channel(channel_id, to, session.user_id())
3139        .await?;
3140
3141    let connection_pool = session.connection_pool().await;
3142    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3143        let channels = channels
3144            .iter()
3145            .filter_map(|channel| {
3146                if role.can_see_channel(channel.visibility) {
3147                    Some(channel.to_proto())
3148                } else {
3149                    None
3150                }
3151            })
3152            .collect::<Vec<_>>();
3153        if channels.is_empty() {
3154            continue;
3155        }
3156
3157        let update = proto::UpdateChannels {
3158            channels,
3159            ..Default::default()
3160        };
3161
3162        session.peer.send(connection_id, update.clone())?;
3163    }
3164
3165    response.send(Ack {})?;
3166    Ok(())
3167}
3168
3169async fn reorder_channel(
3170    request: proto::ReorderChannel,
3171    response: Response<proto::ReorderChannel>,
3172    session: MessageContext,
3173) -> Result<()> {
3174    let channel_id = ChannelId::from_proto(request.channel_id);
3175    let direction = request.direction();
3176
3177    let updated_channels = session
3178        .db()
3179        .await
3180        .reorder_channel(channel_id, direction, session.user_id())
3181        .await?;
3182
3183    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3184        let connection_pool = session.connection_pool().await;
3185        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3186            let channels = updated_channels
3187                .iter()
3188                .filter_map(|channel| {
3189                    if role.can_see_channel(channel.visibility) {
3190                        Some(channel.to_proto())
3191                    } else {
3192                        None
3193                    }
3194                })
3195                .collect::<Vec<_>>();
3196
3197            if channels.is_empty() {
3198                continue;
3199            }
3200
3201            let update = proto::UpdateChannels {
3202                channels,
3203                ..Default::default()
3204            };
3205
3206            session.peer.send(connection_id, update.clone())?;
3207        }
3208    }
3209
3210    response.send(Ack {})?;
3211    Ok(())
3212}
3213
3214/// Get the list of channel members
3215async fn get_channel_members(
3216    request: proto::GetChannelMembers,
3217    response: Response<proto::GetChannelMembers>,
3218    session: MessageContext,
3219) -> Result<()> {
3220    let db = session.db().await;
3221    let channel_id = ChannelId::from_proto(request.channel_id);
3222    let limit = if request.limit == 0 {
3223        u16::MAX as u64
3224    } else {
3225        request.limit
3226    };
3227    let (members, users) = db
3228        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3229        .await?;
3230    response.send(proto::GetChannelMembersResponse { members, users })?;
3231    Ok(())
3232}
3233
3234/// Accept or decline a channel invitation.
3235async fn respond_to_channel_invite(
3236    request: proto::RespondToChannelInvite,
3237    response: Response<proto::RespondToChannelInvite>,
3238    session: MessageContext,
3239) -> Result<()> {
3240    let db = session.db().await;
3241    let channel_id = ChannelId::from_proto(request.channel_id);
3242    let RespondToChannelInvite {
3243        membership_update,
3244        notifications,
3245    } = db
3246        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3247        .await?;
3248
3249    let mut connection_pool = session.connection_pool().await;
3250    if let Some(membership_update) = membership_update {
3251        notify_membership_updated(
3252            &mut connection_pool,
3253            membership_update,
3254            session.user_id(),
3255            &session.peer,
3256        );
3257    } else {
3258        let update = proto::UpdateChannels {
3259            remove_channel_invitations: vec![channel_id.to_proto()],
3260            ..Default::default()
3261        };
3262
3263        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3264            session.peer.send(connection_id, update.clone())?;
3265        }
3266    };
3267
3268    send_notifications(&connection_pool, &session.peer, notifications);
3269
3270    response.send(proto::Ack {})?;
3271
3272    Ok(())
3273}
3274
3275/// Join the channels' room
3276async fn join_channel(
3277    request: proto::JoinChannel,
3278    response: Response<proto::JoinChannel>,
3279    session: MessageContext,
3280) -> Result<()> {
3281    let channel_id = ChannelId::from_proto(request.channel_id);
3282    join_channel_internal(channel_id, Box::new(response), session).await
3283}
3284
/// Abstraction over the response handles that can answer a channel join.
///
/// `join_channel_internal` takes a `Box<impl JoinChannelInternalResponse>`, so the
/// same join logic can reply with a `proto::JoinRoomResponse` to either a
/// `JoinChannel` or a `JoinRoom` request.
trait JoinChannelInternalResponse {
    /// Sends `result` back to the client, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegates to `Response`'s own `send` through an explicit type path
        // (presumably the inherent method). NOTE(review): keep this explicit; a call
        // that resolved to this trait method would recurse forever.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same explicit delegation, for responses to `JoinRoom`.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3298
/// Put the current connection into `channel_id`'s call room and notify everyone affected.
///
/// Shared by handlers whose response type implements [`JoinChannelInternalResponse`]
/// (for example `join_channel`). In order, it:
/// 1. Removes this user's stale room connection from its room first, if the database
///    reports one. The database guard is released around that call.
/// 2. Joins the channel's room in the database.
/// 3. Mints LiveKit credentials when a LiveKit client is configured. Guests get a guest
///    token with `can_publish = false`; everyone else gets a room token with
///    `can_publish = true`. A token error is traced and results in no connection info.
/// 4. Sends the `JoinRoomResponse`, applies any membership change, and sends
///    `RoomUpdated` to the room's participants.
/// 5. After the inner scope ends, tells channel observers about the new participant list
///    and refreshes the user's contacts.
///
/// # Errors
/// Fails on database errors, on a failed response send, or if the joined room has no
/// associated channel ("channel not returned").
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    // The database and connection-pool guards taken inside this block are dropped when
    // it ends, before the connection pool is acquired again for `channel_updated` below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while `leave_room_for_session` runs (it acquires
            // the database itself), then take it again.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or when minting the token failed
        // (the error is traced and `?` short-circuits the closure to `None`).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests receive a guest token and may not publish; all other roles
                    // receive a full room token.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        // Joining may have changed the user's channel memberships (reported by the
        // database); if so, resubscribe and notify the user's connections.
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3392
3393/// Start editing the channel notes
3394async fn join_channel_buffer(
3395    request: proto::JoinChannelBuffer,
3396    response: Response<proto::JoinChannelBuffer>,
3397    session: MessageContext,
3398) -> Result<()> {
3399    let db = session.db().await;
3400    let channel_id = ChannelId::from_proto(request.channel_id);
3401
3402    let open_response = db
3403        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3404        .await?;
3405
3406    let collaborators = open_response.collaborators.clone();
3407    response.send(open_response)?;
3408
3409    let update = UpdateChannelBufferCollaborators {
3410        channel_id: channel_id.to_proto(),
3411        collaborators: collaborators.clone(),
3412    };
3413    channel_buffer_updated(
3414        session.connection_id,
3415        collaborators
3416            .iter()
3417            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3418        &update,
3419        &session.peer,
3420    );
3421
3422    Ok(())
3423}
3424
3425/// Edit the channel notes
3426async fn update_channel_buffer(
3427    request: proto::UpdateChannelBuffer,
3428    session: MessageContext,
3429) -> Result<()> {
3430    let db = session.db().await;
3431    let channel_id = ChannelId::from_proto(request.channel_id);
3432
3433    let (collaborators, epoch, version) = db
3434        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3435        .await?;
3436
3437    channel_buffer_updated(
3438        session.connection_id,
3439        collaborators.clone(),
3440        &proto::UpdateChannelBuffer {
3441            channel_id: channel_id.to_proto(),
3442            operations: request.operations,
3443        },
3444        &session.peer,
3445    );
3446
3447    let pool = &*session.connection_pool().await;
3448
3449    let non_collaborators =
3450        pool.channel_connection_ids(channel_id)
3451            .filter_map(|(connection_id, _)| {
3452                if collaborators.contains(&connection_id) {
3453                    None
3454                } else {
3455                    Some(connection_id)
3456                }
3457            });
3458
3459    broadcast(None, non_collaborators, |peer_id| {
3460        session.peer.send(
3461            peer_id,
3462            proto::UpdateChannels {
3463                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3464                    channel_id: channel_id.to_proto(),
3465                    epoch: epoch as u64,
3466                    version: version.clone(),
3467                }],
3468                ..Default::default()
3469            },
3470        )
3471    });
3472
3473    Ok(())
3474}
3475
3476/// Rejoin the channel notes after a connection blip
3477async fn rejoin_channel_buffers(
3478    request: proto::RejoinChannelBuffers,
3479    response: Response<proto::RejoinChannelBuffers>,
3480    session: MessageContext,
3481) -> Result<()> {
3482    let db = session.db().await;
3483    let buffers = db
3484        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3485        .await?;
3486
3487    for rejoined_buffer in &buffers {
3488        let collaborators_to_notify = rejoined_buffer
3489            .buffer
3490            .collaborators
3491            .iter()
3492            .filter_map(|c| Some(c.peer_id?.into()));
3493        channel_buffer_updated(
3494            session.connection_id,
3495            collaborators_to_notify,
3496            &proto::UpdateChannelBufferCollaborators {
3497                channel_id: rejoined_buffer.buffer.channel_id,
3498                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3499            },
3500            &session.peer,
3501        );
3502    }
3503
3504    response.send(proto::RejoinChannelBuffersResponse {
3505        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3506    })?;
3507
3508    Ok(())
3509}
3510
3511/// Stop editing the channel notes
3512async fn leave_channel_buffer(
3513    request: proto::LeaveChannelBuffer,
3514    response: Response<proto::LeaveChannelBuffer>,
3515    session: MessageContext,
3516) -> Result<()> {
3517    let db = session.db().await;
3518    let channel_id = ChannelId::from_proto(request.channel_id);
3519
3520    let left_buffer = db
3521        .leave_channel_buffer(channel_id, session.connection_id)
3522        .await?;
3523
3524    response.send(Ack {})?;
3525
3526    channel_buffer_updated(
3527        session.connection_id,
3528        left_buffer.connections,
3529        &proto::UpdateChannelBufferCollaborators {
3530            channel_id: channel_id.to_proto(),
3531            collaborators: left_buffer.collaborators,
3532        },
3533        &session.peer,
3534    );
3535
3536    Ok(())
3537}
3538
3539fn channel_buffer_updated<T: EnvelopedMessage>(
3540    sender_id: ConnectionId,
3541    collaborators: impl IntoIterator<Item = ConnectionId>,
3542    message: &T,
3543    peer: &Peer,
3544) {
3545    broadcast(Some(sender_id), collaborators, |peer_id| {
3546        peer.send(peer_id, message.clone())
3547    });
3548}
3549
3550fn send_notifications(
3551    connection_pool: &ConnectionPool,
3552    peer: &Peer,
3553    notifications: db::NotificationBatch,
3554) {
3555    for (user_id, notification) in notifications {
3556        for connection_id in connection_pool.user_connection_ids(user_id) {
3557            if let Err(error) = peer.send(
3558                connection_id,
3559                proto::AddNotification {
3560                    notification: Some(notification.clone()),
3561                },
3562            ) {
3563                tracing::error!(
3564                    "failed to send notification to {:?} {}",
3565                    connection_id,
3566                    error
3567                );
3568            }
3569        }
3570    }
3571}
3572
3573/// Send a message to the channel
3574async fn send_channel_message(
3575    request: proto::SendChannelMessage,
3576    response: Response<proto::SendChannelMessage>,
3577    session: MessageContext,
3578) -> Result<()> {
3579    // Validate the message body.
3580    let body = request.body.trim().to_string();
3581    if body.len() > MAX_MESSAGE_LEN {
3582        return Err(anyhow!("message is too long"))?;
3583    }
3584    if body.is_empty() {
3585        return Err(anyhow!("message can't be blank"))?;
3586    }
3587
3588    // TODO: adjust mentions if body is trimmed
3589
3590    let timestamp = OffsetDateTime::now_utc();
3591    let nonce = request.nonce.context("nonce can't be blank")?;
3592
3593    let channel_id = ChannelId::from_proto(request.channel_id);
3594    let CreatedChannelMessage {
3595        message_id,
3596        participant_connection_ids,
3597        notifications,
3598    } = session
3599        .db()
3600        .await
3601        .create_channel_message(
3602            channel_id,
3603            session.user_id(),
3604            &body,
3605            &request.mentions,
3606            timestamp,
3607            nonce.clone().into(),
3608            request.reply_to_message_id.map(MessageId::from_proto),
3609        )
3610        .await?;
3611
3612    let message = proto::ChannelMessage {
3613        sender_id: session.user_id().to_proto(),
3614        id: message_id.to_proto(),
3615        body,
3616        mentions: request.mentions,
3617        timestamp: timestamp.unix_timestamp() as u64,
3618        nonce: Some(nonce),
3619        reply_to_message_id: request.reply_to_message_id,
3620        edited_at: None,
3621    };
3622    broadcast(
3623        Some(session.connection_id),
3624        participant_connection_ids.clone(),
3625        |connection| {
3626            session.peer.send(
3627                connection,
3628                proto::ChannelMessageSent {
3629                    channel_id: channel_id.to_proto(),
3630                    message: Some(message.clone()),
3631                },
3632            )
3633        },
3634    );
3635    response.send(proto::SendChannelMessageResponse {
3636        message: Some(message),
3637    })?;
3638
3639    let pool = &*session.connection_pool().await;
3640    let non_participants =
3641        pool.channel_connection_ids(channel_id)
3642            .filter_map(|(connection_id, _)| {
3643                if participant_connection_ids.contains(&connection_id) {
3644                    None
3645                } else {
3646                    Some(connection_id)
3647                }
3648            });
3649    broadcast(None, non_participants, |peer_id| {
3650        session.peer.send(
3651            peer_id,
3652            proto::UpdateChannels {
3653                latest_channel_message_ids: vec![proto::ChannelMessageId {
3654                    channel_id: channel_id.to_proto(),
3655                    message_id: message_id.to_proto(),
3656                }],
3657                ..Default::default()
3658            },
3659        )
3660    });
3661    send_notifications(pool, &session.peer, notifications);
3662
3663    Ok(())
3664}
3665
3666/// Delete a channel message
3667async fn remove_channel_message(
3668    request: proto::RemoveChannelMessage,
3669    response: Response<proto::RemoveChannelMessage>,
3670    session: MessageContext,
3671) -> Result<()> {
3672    let channel_id = ChannelId::from_proto(request.channel_id);
3673    let message_id = MessageId::from_proto(request.message_id);
3674    let (connection_ids, existing_notification_ids) = session
3675        .db()
3676        .await
3677        .remove_channel_message(channel_id, message_id, session.user_id())
3678        .await?;
3679
3680    broadcast(
3681        Some(session.connection_id),
3682        connection_ids,
3683        move |connection| {
3684            session.peer.send(connection, request.clone())?;
3685
3686            for notification_id in &existing_notification_ids {
3687                session.peer.send(
3688                    connection,
3689                    proto::DeleteNotification {
3690                        notification_id: (*notification_id).to_proto(),
3691                    },
3692                )?;
3693            }
3694
3695            Ok(())
3696        },
3697    );
3698    response.send(proto::Ack {})?;
3699    Ok(())
3700}
3701
3702async fn update_channel_message(
3703    request: proto::UpdateChannelMessage,
3704    response: Response<proto::UpdateChannelMessage>,
3705    session: MessageContext,
3706) -> Result<()> {
3707    let channel_id = ChannelId::from_proto(request.channel_id);
3708    let message_id = MessageId::from_proto(request.message_id);
3709    let updated_at = OffsetDateTime::now_utc();
3710    let UpdatedChannelMessage {
3711        message_id,
3712        participant_connection_ids,
3713        notifications,
3714        reply_to_message_id,
3715        timestamp,
3716        deleted_mention_notification_ids,
3717        updated_mention_notifications,
3718    } = session
3719        .db()
3720        .await
3721        .update_channel_message(
3722            channel_id,
3723            message_id,
3724            session.user_id(),
3725            request.body.as_str(),
3726            &request.mentions,
3727            updated_at,
3728        )
3729        .await?;
3730
3731    let nonce = request.nonce.clone().context("nonce can't be blank")?;
3732
3733    let message = proto::ChannelMessage {
3734        sender_id: session.user_id().to_proto(),
3735        id: message_id.to_proto(),
3736        body: request.body.clone(),
3737        mentions: request.mentions.clone(),
3738        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3739        nonce: Some(nonce),
3740        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3741        edited_at: Some(updated_at.unix_timestamp() as u64),
3742    };
3743
3744    response.send(proto::Ack {})?;
3745
3746    let pool = &*session.connection_pool().await;
3747    broadcast(
3748        Some(session.connection_id),
3749        participant_connection_ids,
3750        |connection| {
3751            session.peer.send(
3752                connection,
3753                proto::ChannelMessageUpdate {
3754                    channel_id: channel_id.to_proto(),
3755                    message: Some(message.clone()),
3756                },
3757            )?;
3758
3759            for notification_id in &deleted_mention_notification_ids {
3760                session.peer.send(
3761                    connection,
3762                    proto::DeleteNotification {
3763                        notification_id: (*notification_id).to_proto(),
3764                    },
3765                )?;
3766            }
3767
3768            for notification in &updated_mention_notifications {
3769                session.peer.send(
3770                    connection,
3771                    proto::UpdateNotification {
3772                        notification: Some(notification.clone()),
3773                    },
3774                )?;
3775            }
3776
3777            Ok(())
3778        },
3779    );
3780
3781    send_notifications(pool, &session.peer, notifications);
3782
3783    Ok(())
3784}
3785
3786/// Mark a channel message as read
3787async fn acknowledge_channel_message(
3788    request: proto::AckChannelMessage,
3789    session: MessageContext,
3790) -> Result<()> {
3791    let channel_id = ChannelId::from_proto(request.channel_id);
3792    let message_id = MessageId::from_proto(request.message_id);
3793    let notifications = session
3794        .db()
3795        .await
3796        .observe_channel_message(channel_id, session.user_id(), message_id)
3797        .await?;
3798    send_notifications(
3799        &*session.connection_pool().await,
3800        &session.peer,
3801        notifications,
3802    );
3803    Ok(())
3804}
3805
3806/// Mark a buffer version as synced
3807async fn acknowledge_buffer_version(
3808    request: proto::AckBufferOperation,
3809    session: MessageContext,
3810) -> Result<()> {
3811    let buffer_id = BufferId::from_proto(request.buffer_id);
3812    session
3813        .db()
3814        .await
3815        .observe_buffer_version(
3816            buffer_id,
3817            session.user_id(),
3818            request.epoch as i32,
3819            &request.version,
3820        )
3821        .await?;
3822    Ok(())
3823}
3824
3825/// Get a Supermaven API key for the user
3826async fn get_supermaven_api_key(
3827    _request: proto::GetSupermavenApiKey,
3828    response: Response<proto::GetSupermavenApiKey>,
3829    session: MessageContext,
3830) -> Result<()> {
3831    let user_id: String = session.user_id().to_string();
3832    if !session.is_staff() {
3833        return Err(anyhow!("supermaven not enabled for this account"))?;
3834    }
3835
3836    let email = session.email().context("user must have an email")?;
3837
3838    let supermaven_admin_api = session
3839        .supermaven_client
3840        .as_ref()
3841        .context("supermaven not configured")?;
3842
3843    let result = supermaven_admin_api
3844        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3845        .await?;
3846
3847    response.send(proto::GetSupermavenApiKeyResponse {
3848        api_key: result.api_key,
3849    })?;
3850
3851    Ok(())
3852}
3853
3854/// Start receiving chat updates for a channel
3855async fn join_channel_chat(
3856    request: proto::JoinChannelChat,
3857    response: Response<proto::JoinChannelChat>,
3858    session: MessageContext,
3859) -> Result<()> {
3860    let channel_id = ChannelId::from_proto(request.channel_id);
3861
3862    let db = session.db().await;
3863    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3864        .await?;
3865    let messages = db
3866        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3867        .await?;
3868    response.send(proto::JoinChannelChatResponse {
3869        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3870        messages,
3871    })?;
3872    Ok(())
3873}
3874
3875/// Stop receiving chat updates for a channel
3876async fn leave_channel_chat(
3877    request: proto::LeaveChannelChat,
3878    session: MessageContext,
3879) -> Result<()> {
3880    let channel_id = ChannelId::from_proto(request.channel_id);
3881    session
3882        .db()
3883        .await
3884        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3885        .await?;
3886    Ok(())
3887}
3888
3889/// Retrieve the chat history for a channel
3890async fn get_channel_messages(
3891    request: proto::GetChannelMessages,
3892    response: Response<proto::GetChannelMessages>,
3893    session: MessageContext,
3894) -> Result<()> {
3895    let channel_id = ChannelId::from_proto(request.channel_id);
3896    let messages = session
3897        .db()
3898        .await
3899        .get_channel_messages(
3900            channel_id,
3901            session.user_id(),
3902            MESSAGE_COUNT_PER_PAGE,
3903            Some(MessageId::from_proto(request.before_message_id)),
3904        )
3905        .await?;
3906    response.send(proto::GetChannelMessagesResponse {
3907        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3908        messages,
3909    })?;
3910    Ok(())
3911}
3912
3913/// Retrieve specific chat messages
3914async fn get_channel_messages_by_id(
3915    request: proto::GetChannelMessagesById,
3916    response: Response<proto::GetChannelMessagesById>,
3917    session: MessageContext,
3918) -> Result<()> {
3919    let message_ids = request
3920        .message_ids
3921        .iter()
3922        .map(|id| MessageId::from_proto(*id))
3923        .collect::<Vec<_>>();
3924    let messages = session
3925        .db()
3926        .await
3927        .get_channel_messages_by_id(session.user_id(), &message_ids)
3928        .await?;
3929    response.send(proto::GetChannelMessagesResponse {
3930        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3931        messages,
3932    })?;
3933    Ok(())
3934}
3935
3936/// Retrieve the current users notifications
3937async fn get_notifications(
3938    request: proto::GetNotifications,
3939    response: Response<proto::GetNotifications>,
3940    session: MessageContext,
3941) -> Result<()> {
3942    let notifications = session
3943        .db()
3944        .await
3945        .get_notifications(
3946            session.user_id(),
3947            NOTIFICATION_COUNT_PER_PAGE,
3948            request.before_id.map(db::NotificationId::from_proto),
3949        )
3950        .await?;
3951    response.send(proto::GetNotificationsResponse {
3952        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3953        notifications,
3954    })?;
3955    Ok(())
3956}
3957
3958/// Mark notifications as read
3959async fn mark_notification_as_read(
3960    request: proto::MarkNotificationRead,
3961    response: Response<proto::MarkNotificationRead>,
3962    session: MessageContext,
3963) -> Result<()> {
3964    let database = &session.db().await;
3965    let notifications = database
3966        .mark_notification_as_read_by_id(
3967            session.user_id(),
3968            NotificationId::from_proto(request.notification_id),
3969        )
3970        .await?;
3971    send_notifications(
3972        &*session.connection_pool().await,
3973        &session.peer,
3974        notifications,
3975    );
3976    response.send(proto::Ack {})?;
3977    Ok(())
3978}
3979
3980fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3981    let message = match message {
3982        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3983        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3984        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3985        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3986        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3987            code: frame.code.into(),
3988            reason: frame.reason.as_str().to_owned().into(),
3989        })),
3990        // We should never receive a frame while reading the message, according
3991        // to the `tungstenite` maintainers:
3992        //
3993        // > It cannot occur when you read messages from the WebSocket, but it
3994        // > can be used when you want to send the raw frames (e.g. you want to
3995        // > send the frames to the WebSocket without composing the full message first).
3996        // >
3997        // > — https://github.com/snapview/tungstenite-rs/issues/268
3998        TungsteniteMessage::Frame(_) => {
3999            bail!("received an unexpected frame while reading the message")
4000        }
4001    };
4002
4003    Ok(message)
4004}
4005
4006fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4007    match message {
4008        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4009        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4010        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4011        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4012        AxumMessage::Close(frame) => {
4013            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4014                code: frame.code.into(),
4015                reason: frame.reason.as_ref().into(),
4016            }))
4017        }
4018    }
4019}
4020
4021fn notify_membership_updated(
4022    connection_pool: &mut ConnectionPool,
4023    result: MembershipUpdated,
4024    user_id: UserId,
4025    peer: &Peer,
4026) {
4027    for membership in &result.new_channels.channel_memberships {
4028        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4029    }
4030    for channel_id in &result.removed_channels {
4031        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4032    }
4033
4034    let user_channels_update = proto::UpdateUserChannels {
4035        channel_memberships: result
4036            .new_channels
4037            .channel_memberships
4038            .iter()
4039            .map(|cm| proto::ChannelMembership {
4040                channel_id: cm.channel_id.to_proto(),
4041                role: cm.role.into(),
4042            })
4043            .collect(),
4044        ..Default::default()
4045    };
4046
4047    let mut update = build_channels_update(result.new_channels);
4048    update.delete_channels = result
4049        .removed_channels
4050        .into_iter()
4051        .map(|id| id.to_proto())
4052        .collect();
4053    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4054
4055    for connection_id in connection_pool.user_connection_ids(user_id) {
4056        peer.send(connection_id, user_channels_update.clone())
4057            .trace_err();
4058        peer.send(connection_id, update.clone()).trace_err();
4059    }
4060}
4061
4062fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4063    proto::UpdateUserChannels {
4064        channel_memberships: channels
4065            .channel_memberships
4066            .iter()
4067            .map(|m| proto::ChannelMembership {
4068                channel_id: m.channel_id.to_proto(),
4069                role: m.role.into(),
4070            })
4071            .collect(),
4072        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4073        observed_channel_message_id: channels.observed_channel_messages.clone(),
4074    }
4075}
4076
4077fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4078    let mut update = proto::UpdateChannels::default();
4079
4080    for channel in channels.channels {
4081        update.channels.push(channel.to_proto());
4082    }
4083
4084    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4085    update.latest_channel_message_ids = channels.latest_channel_messages;
4086
4087    for (channel_id, participants) in channels.channel_participants {
4088        update
4089            .channel_participants
4090            .push(proto::ChannelParticipants {
4091                channel_id: channel_id.to_proto(),
4092                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4093            });
4094    }
4095
4096    for channel in channels.invited_channels {
4097        update.channel_invitations.push(channel.to_proto());
4098    }
4099
4100    update
4101}
4102
4103fn build_initial_contacts_update(
4104    contacts: Vec<db::Contact>,
4105    pool: &ConnectionPool,
4106) -> proto::UpdateContacts {
4107    let mut update = proto::UpdateContacts::default();
4108
4109    for contact in contacts {
4110        match contact {
4111            db::Contact::Accepted { user_id, busy } => {
4112                update.contacts.push(contact_for_user(user_id, busy, pool));
4113            }
4114            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4115            db::Contact::Incoming { user_id } => {
4116                update
4117                    .incoming_requests
4118                    .push(proto::IncomingContactRequest {
4119                        requester_id: user_id.to_proto(),
4120                    })
4121            }
4122        }
4123    }
4124
4125    update
4126}
4127
4128fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4129    proto::Contact {
4130        user_id: user_id.to_proto(),
4131        online: pool.is_user_online(user_id),
4132        busy,
4133    }
4134}
4135
4136fn room_updated(room: &proto::Room, peer: &Peer) {
4137    broadcast(
4138        None,
4139        room.participants
4140            .iter()
4141            .filter_map(|participant| Some(participant.peer_id?.into())),
4142        |peer_id| {
4143            peer.send(
4144                peer_id,
4145                proto::RoomUpdated {
4146                    room: Some(room.clone()),
4147                },
4148            )
4149        },
4150    );
4151}
4152
4153fn channel_updated(
4154    channel: &db::channel::Model,
4155    room: &proto::Room,
4156    peer: &Peer,
4157    pool: &ConnectionPool,
4158) {
4159    let participants = room
4160        .participants
4161        .iter()
4162        .map(|p| p.user_id)
4163        .collect::<Vec<_>>();
4164
4165    broadcast(
4166        None,
4167        pool.channel_connection_ids(channel.root_id())
4168            .filter_map(|(channel_id, role)| {
4169                role.can_see_channel(channel.visibility)
4170                    .then_some(channel_id)
4171            }),
4172        |peer_id| {
4173            peer.send(
4174                peer_id,
4175                proto::UpdateChannels {
4176                    channel_participants: vec![proto::ChannelParticipants {
4177                        channel_id: channel.id.to_proto(),
4178                        participant_user_ids: participants.clone(),
4179                    }],
4180                    ..Default::default()
4181                },
4182            )
4183        },
4184    );
4185}
4186
4187async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4188    let db = session.db().await;
4189
4190    let contacts = db.get_contacts(user_id).await?;
4191    let busy = db.is_user_busy(user_id).await?;
4192
4193    let pool = session.connection_pool().await;
4194    let updated_contact = contact_for_user(user_id, busy, &pool);
4195    for contact in contacts {
4196        if let db::Contact::Accepted {
4197            user_id: contact_user_id,
4198            ..
4199        } = contact
4200        {
4201            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4202                session
4203                    .peer
4204                    .send(
4205                        contact_conn_id,
4206                        proto::UpdateContacts {
4207                            contacts: vec![updated_contact.clone()],
4208                            remove_contacts: Default::default(),
4209                            incoming_requests: Default::default(),
4210                            remove_incoming_requests: Default::default(),
4211                            outgoing_requests: Default::default(),
4212                            remove_outgoing_requests: Default::default(),
4213                        },
4214                    )
4215                    .trace_err();
4216            }
4217        }
4218    }
4219    Ok(())
4220}
4221
/// Removes `connection_id` from its room (if any) and propagates the effects.
///
/// The effects are, in order:
/// - project notifications (`project_left`);
/// - room and channel participant broadcasts;
/// - `CallCanceled` notifications for canceled calls;
/// - contact refreshes;
/// - LiveKit participant and room cleanup.
///
/// Returns `Ok(())` immediately when `leave_room` returns `None`.
///
/// # Errors
///
/// Propagates errors from `Database::leave_room` and `update_user_contacts`.
/// NOTE(review): because `update_user_contacts(...).await?` returns early on
/// error, such a failure skips the LiveKit cleanup at the end.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    // Users whose contact entries must be re-broadcast: the leaving user, plus
    // every user whose pending call is canceled (added below).
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: these are assigned inside the `if let` so that
    // `left_room` and the temporary `session.db()` guard in its scrutinee are
    // dropped when that block ends, before the `.await`s further down.
    // NOTE(review): presumably `left_room` holds a room-level lock; confirm
    // against `Database::leave_room`.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out. As a result, the `room` broadcast
        // by `room_updated` below carries the default (empty) `livekit_room`.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    // Channel rooms also push the new participant list to channel viewers.
    // The pool guard is a temporary and is dropped at the end of this statement.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before
    // `update_user_contacts`, which acquires the pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // Best effort: LiveKit failures are logged via `trace_err` and do not fail
    // the call. The LiveKit room is deleted only when `left_room.deleted` was set.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4295
4296async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4297    let left_channel_buffers = session
4298        .db()
4299        .await
4300        .leave_channel_buffers(session.connection_id)
4301        .await?;
4302
4303    for left_buffer in left_channel_buffers {
4304        channel_buffer_updated(
4305            session.connection_id,
4306            left_buffer.connections,
4307            &proto::UpdateChannelBufferCollaborators {
4308                channel_id: left_buffer.channel_id.to_proto(),
4309                collaborators: left_buffer.collaborators,
4310            },
4311            &session.peer,
4312        );
4313    }
4314
4315    Ok(())
4316}
4317
4318fn project_left(project: &db::LeftProject, session: &Session) {
4319    for connection_id in &project.connection_ids {
4320        if project.should_unshare {
4321            session
4322                .peer
4323                .send(
4324                    *connection_id,
4325                    proto::UnshareProject {
4326                        project_id: project.id.to_proto(),
4327                    },
4328                )
4329                .trace_err();
4330        } else {
4331            session
4332                .peer
4333                .send(
4334                    *connection_id,
4335                    proto::RemoveProjectCollaborator {
4336                        project_id: project.id.to_proto(),
4337                        peer_id: Some(session.connection_id.into()),
4338                    },
4339                )
4340                .trace_err();
4341        }
4342    }
4343}
4344
4345pub trait ResultExt {
4346    type Ok;
4347
4348    fn trace_err(self) -> Option<Self::Ok>;
4349}
4350
4351impl<T, E> ResultExt for Result<T, E>
4352where
4353    E: std::fmt::Debug,
4354{
4355    type Ok = T;
4356
4357    #[track_caller]
4358    fn trace_err(self) -> Option<T> {
4359        match self {
4360            Ok(value) => Some(value),
4361            Err(error) => {
4362                tracing::error!("{:?}", error);
4363                None
4364            }
4365        }
4366    }
4367}