rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::{
   5    AppState, Error, Result, auth,
   6    db::{
   7        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   8        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
   9        NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
  10        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  11    },
  12    executor::Executor,
  13};
  14use anyhow::{Context as _, anyhow, bail};
  15use async_tungstenite::tungstenite::{
  16    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  17};
  18use axum::headers::UserAgent;
  19use axum::{
  20    Extension, Router, TypedHeader,
  21    body::Body,
  22    extract::{
  23        ConnectInfo, WebSocketUpgrade,
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use futures::TryFutureExt as _;
  36use reqwest_client::ReqwestClient;
  37use rpc::proto::{MultiLspQuery, split_repository_update};
  38use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  39use tracing::Span;
  40
  41use futures::{
  42    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  43    stream::FuturesUnordered,
  44};
  45use prometheus::{IntGauge, register_int_gauge};
  46use rpc::{
  47    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  50        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52};
  53use semantic_version::SemanticVersion;
  54use serde::{Serialize, Serializer};
  55use std::{
  56    any::TypeId,
  57    future::Future,
  58    marker::PhantomData,
  59    mem,
  60    net::SocketAddr,
  61    ops::{Deref, DerefMut},
  62    rc::Rc,
  63    sync::{
  64        Arc, OnceLock,
  65        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  66    },
  67    time::{Duration, Instant},
  68};
  69use time::OffsetDateTime;
  70use tokio::sync::{Semaphore, watch};
  71use tower::ServiceBuilder;
  72use tracing::{
  73    Instrument,
  74    field::{self},
  75    info_span, instrument,
  76};
  77
/// Reconnect window for clients. NOTE(review): presumably how long a
/// disconnected client's state is retained before cleanup — confirm at the
/// usage sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long `Server::start` waits before cleaning up resources that the
/// database reports as stale for this server.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. Once they are
/// gone we can clean up their old resources, so this is set somewhat above
/// that grace period.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the next three limits are presumably page sizes / length caps
// for channel-message and notification handlers — confirm at the usage sites.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Maximum number of simultaneously held connection slots; enforced by
/// `ConnectionGuard::try_acquire`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Number of connection slots currently held: incremented by
/// `ConnectionGuard::try_acquire`, decremented when a guard is dropped.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Span field names under which per-message timings are recorded (values are
// in milliseconds, derived from microsecond precision).
/// Time from message receipt until the handler's future completes.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
/// Time from invoking the handler until its future completes.
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
/// Total duration minus processing duration, i.e. time spent before the handler ran.
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
/// Time spent waiting for the host while forwarding a request.
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased message handler. It receives a boxed envelope, the session of
/// the connection that sent it, and the span for that message, and returns a
/// boxed `'static` future that performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
  97
  98pub struct ConnectionGuard;
  99
 100impl ConnectionGuard {
 101    pub fn try_acquire() -> Result<Self, ()> {
 102        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
 103        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 104            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 105            tracing::error!(
 106                "too many concurrent connections: {}",
 107                current_connections + 1
 108            );
 109            return Err(());
 110        }
 111        Ok(ConnectionGuard)
 112    }
 113}
 114
 115impl Drop for ConnectionGuard {
 116    fn drop(&mut self) {
 117        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 118    }
 119}
 120
 121struct Response<R> {
 122    peer: Arc<Peer>,
 123    receipt: Receipt<R>,
 124    responded: Arc<AtomicBool>,
 125}
 126
 127impl<R: RequestMessage> Response<R> {
 128    fn send(self, payload: R::Response) -> Result<()> {
 129        self.responded.store(true, SeqCst);
 130        self.peer.respond(self.receipt, payload)?;
 131        Ok(())
 132    }
 133}
 134
 135#[derive(Clone, Debug)]
 136pub enum Principal {
 137    User(User),
 138    Impersonated { user: User, admin: User },
 139}
 140
 141impl Principal {
 142    fn update_span(&self, span: &tracing::Span) {
 143        match &self {
 144            Principal::User(user) => {
 145                span.record("user_id", user.id.0);
 146                span.record("login", &user.github_login);
 147            }
 148            Principal::Impersonated { user, admin } => {
 149                span.record("user_id", user.id.0);
 150                span.record("login", &user.github_login);
 151                span.record("impersonator", &admin.github_login);
 152            }
 153        }
 154    }
 155}
 156
 157#[derive(Clone)]
 158struct MessageContext {
 159    session: Session,
 160    span: tracing::Span,
 161}
 162
 163impl Deref for MessageContext {
 164    type Target = Session;
 165
 166    fn deref(&self) -> &Self::Target {
 167        &self.session
 168    }
 169}
 170
 171impl MessageContext {
 172    pub fn forward_request<T: RequestMessage>(
 173        &self,
 174        receiver_id: ConnectionId,
 175        request: T,
 176    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 177        let request_start_time = Instant::now();
 178        let span = self.span.clone();
 179        tracing::info!("start forwarding request");
 180        self.peer
 181            .forward_request(self.connection_id, receiver_id, request)
 182            .inspect(move |_| {
 183                span.record(
 184                    HOST_WAITING_MS,
 185                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 186                );
 187            })
 188            .inspect_err(|_| tracing::error!("error forwarding request"))
 189            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 190    }
 191}
 192
 193#[derive(Clone)]
 194struct Session {
 195    principal: Principal,
 196    connection_id: ConnectionId,
 197    db: Arc<tokio::sync::Mutex<DbHandle>>,
 198    peer: Arc<Peer>,
 199    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 200    app_state: Arc<AppState>,
 201    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 202    /// The GeoIP country code for the user.
 203    #[allow(unused)]
 204    geoip_country_code: Option<String>,
 205    #[allow(unused)]
 206    system_id: Option<String>,
 207    _executor: Executor,
 208}
 209
 210impl Session {
 211    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 212        #[cfg(test)]
 213        tokio::task::yield_now().await;
 214        let guard = self.db.lock().await;
 215        #[cfg(test)]
 216        tokio::task::yield_now().await;
 217        guard
 218    }
 219
 220    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 221        #[cfg(test)]
 222        tokio::task::yield_now().await;
 223        let guard = self.connection_pool.lock();
 224        ConnectionPoolGuard {
 225            guard,
 226            _not_send: PhantomData,
 227        }
 228    }
 229
 230    fn is_staff(&self) -> bool {
 231        match &self.principal {
 232            Principal::User(user) => user.admin,
 233            Principal::Impersonated { .. } => true,
 234        }
 235    }
 236
 237    fn user_id(&self) -> UserId {
 238        match &self.principal {
 239            Principal::User(user) => user.id,
 240            Principal::Impersonated { user, .. } => user.id,
 241        }
 242    }
 243
 244    pub fn email(&self) -> Option<String> {
 245        match &self.principal {
 246            Principal::User(user) => user.email_address.clone(),
 247            Principal::Impersonated { user, .. } => user.email_address.clone(),
 248        }
 249    }
 250}
 251
 252impl Debug for Session {
 253    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 254        let mut result = f.debug_struct("Session");
 255        match &self.principal {
 256            Principal::User(user) => {
 257                result.field("user", &user.github_login);
 258            }
 259            Principal::Impersonated { user, admin } => {
 260                result.field("user", &user.github_login);
 261                result.field("impersonator", &admin.github_login);
 262            }
 263        }
 264        result.field("connection_id", &self.connection_id).finish()
 265    }
 266}
 267
 268struct DbHandle(Arc<Database>);
 269
 270impl Deref for DbHandle {
 271    type Target = Database;
 272
 273    fn deref(&self) -> &Self::Target {
 274        self.0.as_ref()
 275    }
 276}
 277
/// The RPC server: owns the [`Peer`], the shared connection pool and the table
/// of registered message handlers.
pub struct Server {
    /// Current server id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Type-erased handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Sent `true` by `teardown` and `false` by the test-only `reset`.
    teardown: watch::Sender<bool>,
}
 286
/// Lock guard over the shared [`ConnectionPool`].
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is `!Send`, so this marker makes the guard `!Send`. The guard
    // therefore cannot be held across an `.await` in a future that must be `Send`.
    _not_send: PhantomData<Rc<()>>,
}
 291
/// Serializable view of the server's peer and connection pool.
///
/// The snapshot owns a `ConnectionPoolGuard`, so the pool stays locked for as
/// long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through the guard's `Deref` target rather than the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 298
 299pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 300where
 301    S: Serializer,
 302    T: Deref<Target = U>,
 303    U: Serialize,
 304{
 305    Serialize::serialize(value.deref(), serializer)
 306}
 307
 308impl Server {
 309    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 310        let mut server = Self {
 311            id: parking_lot::Mutex::new(id),
 312            peer: Peer::new(id.0 as u32),
 313            app_state: app_state.clone(),
 314            connection_pool: Default::default(),
 315            handlers: Default::default(),
 316            teardown: watch::channel(false).0,
 317        };
 318
 319        server
 320            .add_request_handler(ping)
 321            .add_request_handler(create_room)
 322            .add_request_handler(join_room)
 323            .add_request_handler(rejoin_room)
 324            .add_request_handler(leave_room)
 325            .add_request_handler(set_room_participant_role)
 326            .add_request_handler(call)
 327            .add_request_handler(cancel_call)
 328            .add_message_handler(decline_call)
 329            .add_request_handler(update_participant_location)
 330            .add_request_handler(share_project)
 331            .add_message_handler(unshare_project)
 332            .add_request_handler(join_project)
 333            .add_message_handler(leave_project)
 334            .add_request_handler(update_project)
 335            .add_request_handler(update_worktree)
 336            .add_request_handler(update_repository)
 337            .add_request_handler(remove_repository)
 338            .add_message_handler(start_language_server)
 339            .add_message_handler(update_language_server)
 340            .add_message_handler(update_diagnostic_summary)
 341            .add_message_handler(update_worktree_settings)
 342            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 343            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 344            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 345            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 346            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 347            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 348            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 349            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 350            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 351            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 352            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 353            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 354            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 355            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 356            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 357            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 358            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 359            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 360            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 361            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 362            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 363            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 364            .add_request_handler(
 365                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 366            )
 367            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 368            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 369            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 370            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 371            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
 372            .add_request_handler(
 373                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 374            )
 375            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 376            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 377            .add_request_handler(
 378                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 379            )
 380            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 381            .add_request_handler(
 382                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 383            )
 384            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 385            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 386            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 387            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 388            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 389            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 390            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 391            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 392            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 393            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 394            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 395            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 396            .add_request_handler(
 397                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 398            )
 399            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 400            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 401            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 402            .add_request_handler(multi_lsp_query)
 403            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 404            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 405            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 406            .add_message_handler(create_buffer_for_peer)
 407            .add_request_handler(update_buffer)
 408            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 409            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 410            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 411            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 412            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 413            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 414            .add_message_handler(
 415                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 416            )
 417            .add_request_handler(get_users)
 418            .add_request_handler(fuzzy_search_users)
 419            .add_request_handler(request_contact)
 420            .add_request_handler(remove_contact)
 421            .add_request_handler(respond_to_contact_request)
 422            .add_message_handler(subscribe_to_channels)
 423            .add_request_handler(create_channel)
 424            .add_request_handler(delete_channel)
 425            .add_request_handler(invite_channel_member)
 426            .add_request_handler(remove_channel_member)
 427            .add_request_handler(set_channel_member_role)
 428            .add_request_handler(set_channel_visibility)
 429            .add_request_handler(rename_channel)
 430            .add_request_handler(join_channel_buffer)
 431            .add_request_handler(leave_channel_buffer)
 432            .add_message_handler(update_channel_buffer)
 433            .add_request_handler(rejoin_channel_buffers)
 434            .add_request_handler(get_channel_members)
 435            .add_request_handler(respond_to_channel_invite)
 436            .add_request_handler(join_channel)
 437            .add_request_handler(join_channel_chat)
 438            .add_message_handler(leave_channel_chat)
 439            .add_request_handler(send_channel_message)
 440            .add_request_handler(remove_channel_message)
 441            .add_request_handler(update_channel_message)
 442            .add_request_handler(get_channel_messages)
 443            .add_request_handler(get_channel_messages_by_id)
 444            .add_request_handler(get_notifications)
 445            .add_request_handler(mark_notification_as_read)
 446            .add_request_handler(move_channel)
 447            .add_request_handler(reorder_channel)
 448            .add_request_handler(follow)
 449            .add_message_handler(unfollow)
 450            .add_message_handler(update_followers)
 451            .add_message_handler(acknowledge_channel_message)
 452            .add_message_handler(acknowledge_buffer_version)
 453            .add_request_handler(get_supermaven_api_key)
 454            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 455            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 456            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 457            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 458            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 459            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 460            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 461            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 462            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 463            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 464            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 465            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 466            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 467            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 468            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 469            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 470            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 471            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 472            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 473            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 474            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 475            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 476            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 477            .add_message_handler(update_context);
 478
 479        Arc::new(server)
 480    }
 481
    /// Spawns a detached background task that waits `CLEANUP_TIMEOUT` and then
    /// cleans up database state that the database reports as stale relative to
    /// this server's id, within `app_state.config.zed_environment`.
    ///
    /// The task runs these steps in order:
    /// 1. Deletes stale channel-chat participants.
    /// 2. Clears stale channel-buffer collaborators and sends
    ///    `UpdateChannelBufferCollaborators` to the connections the database
    ///    returns.
    /// 3. For each stale room: clears stale participants, calls `room_updated`
    ///    and `channel_updated`, and sends `CallCanceled` for canceled calls. It
    ///    then pushes an updated contact entry to each affected user's accepted
    ///    contacts, and deletes the LiveKit room once no participants remain.
    /// 4. Deletes stale channel-chat participants again, clears old worktree
    ///    entries, and deletes stale server records.
    ///
    /// Every step is best-effort: results go through `trace_err()` and a
    /// failure never aborts the remaining steps. This method itself returns
    /// `Ok(())` as soon as the task is spawned.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created now, awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                // Give terminating pods of the previous deployment time to exit
                // before touching the resources they owned.
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and send
                    // the refreshed collaborator list to the connections returned.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Collected from the refreshed room (if the DB call
                        // succeeds) and acted on below.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and canceled callees need
                            // their contact status refreshed.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // Only delete the LiveKit room once nobody is left in it.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Synchronous section: the pool lock is released before
                        // the next `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's updated contact entry to all
                        // connections of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Skipped unless both lookups succeeded.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                // Taken after the awaits above and dropped at the
                                // end of this block.
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                // NOTE(review): same call as at the start of this task; the
                // repetition looks redundant — confirm whether it is intentional.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                // Runs last, after every per-resource cleanup above.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 652
 653    pub fn teardown(&self) {
 654        self.peer.teardown();
 655        self.connection_pool.lock().reset();
 656        let _ = self.teardown.send(true);
 657    }
 658
 659    #[cfg(test)]
 660    pub fn reset(&self, id: ServerId) {
 661        self.teardown();
 662        *self.id.lock() = id;
 663        self.peer.reset(id.0 as u32);
 664        let _ = self.teardown.send(false);
 665    }
 666
 667    #[cfg(test)]
 668    pub fn id(&self) -> ServerId {
 669        *self.id.lock()
 670    }
 671
    /// Registers `handler` as the handler for messages of type `M`.
    ///
    /// The stored closure does the following:
    /// - downcasts the incoming envelope to `TypedEnvelope<M>`;
    /// - logs "message received";
    /// - calls `handler` with the envelope and a `MessageContext` built from the session and
    ///   span;
    /// - returns a boxed future that awaits the handler, records timing on `span`, and logs
    ///   whether handling failed or finished.
    ///
    /// The recorded durations, all in milliseconds, are total, processing and queue time.
    ///
    /// # Panics
    ///
    /// Panics if a handler for `M` was already registered. The stored closure also panics if it
    /// is invoked with an envelope that is not a `TypedEnvelope<M>`. The table is keyed by
    /// `TypeId::of::<M>()`, so callers are expected to dispatch only matching envelopes.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                // Captured before the envelope is moved into the handler below.
                let received_at = envelope.received_at;
                tracing::info!("message received");
                // Processing time is measured from just before the handler future is created.
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    // Time spent between receipt and the start of processing.
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
 715
 716    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 717    where
 718        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 719        Fut: 'static + Send + Future<Output = Result<()>>,
 720        M: EnvelopedMessage,
 721    {
 722        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 723        self
 724    }
 725
 726    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 727    where
 728        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 729        Fut: Send + Future<Output = Result<()>>,
 730        M: RequestMessage,
 731    {
 732        let handler = Arc::new(handler);
 733        self.add_handler(move |envelope, session| {
 734            let receipt = envelope.receipt();
 735            let handler = handler.clone();
 736            async move {
 737                let peer = session.peer.clone();
 738                let responded = Arc::new(AtomicBool::default());
 739                let response = Response {
 740                    peer: peer.clone(),
 741                    responded: responded.clone(),
 742                    receipt,
 743                };
 744                match (handler)(envelope.payload, response, session).await {
 745                    Ok(()) => {
 746                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 747                            Ok(())
 748                        } else {
 749                            Err(anyhow!("handler did not send a response"))?
 750                        }
 751                    }
 752                    Err(error) => {
 753                        let proto_err = match &error {
 754                            Error::Internal(err) => err.to_proto(),
 755                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 756                        };
 757                        peer.respond_with_error(receipt, proto_err)?;
 758                        Err(error)
 759                    }
 760                }
 761            }
 762        })
 763    }
 764
    /// Serves a single client connection until it closes or the server tears down.
    ///
    /// Setup, in order:
    /// - records connection metadata on a "handle connection" span;
    /// - registers the connection with the peer;
    /// - builds the per-connection `Session`, including a Supermaven admin client when
    ///   `supermaven_admin_api_key` is configured;
    /// - sends the initial client update.
    ///
    /// It then dispatches incoming messages to the registered handlers. This continues until the
    /// incoming stream ends, the I/O task finishes, or the teardown watch channel changes.
    /// Unless it stopped because of teardown, it finally runs `connection_lost` for the session.
    ///
    /// Parameters:
    /// * `address` – remote socket address, recorded on every span.
    /// * `send_connection_id` – if provided, receives the assigned `ConnectionId` after the
    ///   hello message has been sent.
    /// * `connection_guard` – released once the initial client update has been sent
    ///   successfully.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse the connection outright if the teardown flag is already set.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // Release the guard only after the initial update succeeded. On the early returns
            // above it is dropped along with the rest of this future's state.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            //
            // A permit is taken before each message is read. It is released only when that
            // message's handler (foreground or background) finishes. As a result, at most
            // MAX_CONCURRENT_HANDLERS handlers are in flight for this connection at once.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            // Number of permits currently held, i.e. handlers in flight.
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in the order written: teardown, I/O completion,
                // in-flight foreground handlers, then the next incoming message.
                futures::select_biased! {
                    // Teardown exits without running `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    // The I/O task finished (cleanly or with an error); stop and sign out.
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                multi_lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor. Foreground
                                // ones are polled by this loop via `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that are still pending.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
 956
 957    async fn send_initial_client_update(
 958        &self,
 959        connection_id: ConnectionId,
 960        zed_version: ZedVersion,
 961        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 962        session: &Session,
 963    ) -> Result<()> {
 964        self.peer.send(
 965            connection_id,
 966            proto::Hello {
 967                peer_id: Some(connection_id.into()),
 968            },
 969        )?;
 970        tracing::info!("sent hello message");
 971        if let Some(send_connection_id) = send_connection_id.take() {
 972            let _ = send_connection_id.send(connection_id);
 973        }
 974
 975        match &session.principal {
 976            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 977                if !user.connected_once {
 978                    self.peer.send(connection_id, proto::ShowContacts {})?;
 979                    self.app_state
 980                        .db
 981                        .set_user_connected_once(user.id, true)
 982                        .await?;
 983                }
 984
 985                let contacts = self.app_state.db.get_contacts(user.id).await?;
 986
 987                {
 988                    let mut pool = self.connection_pool.lock();
 989                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 990                    self.peer.send(
 991                        connection_id,
 992                        build_initial_contacts_update(contacts, &pool),
 993                    )?;
 994                }
 995
 996                if should_auto_subscribe_to_channels(zed_version) {
 997                    subscribe_user_to_channels(user.id, session).await?;
 998                }
 999
1000                if let Some(incoming_call) =
1001                    self.app_state.db.incoming_call_for_user(user.id).await?
1002                {
1003                    self.peer.send(connection_id, incoming_call)?;
1004                }
1005
1006                update_user_contacts(user.id, session).await?;
1007            }
1008        }
1009
1010        Ok(())
1011    }
1012
1013    pub async fn invite_code_redeemed(
1014        self: &Arc<Self>,
1015        inviter_id: UserId,
1016        invitee_id: UserId,
1017    ) -> Result<()> {
1018        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await?
1019            && let Some(code) = &user.invite_code
1020        {
1021            let pool = self.connection_pool.lock();
1022            let invitee_contact = contact_for_user(invitee_id, false, &pool);
1023            for connection_id in pool.user_connection_ids(inviter_id) {
1024                self.peer.send(
1025                    connection_id,
1026                    proto::UpdateContacts {
1027                        contacts: vec![invitee_contact.clone()],
1028                        ..Default::default()
1029                    },
1030                )?;
1031                self.peer.send(
1032                    connection_id,
1033                    proto::UpdateInviteInfo {
1034                        url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1035                        count: user.invite_count as u32,
1036                    },
1037                )?;
1038            }
1039        }
1040        Ok(())
1041    }
1042
1043    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1044        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await?
1045            && let Some(invite_code) = &user.invite_code
1046        {
1047            let pool = self.connection_pool.lock();
1048            for connection_id in pool.user_connection_ids(user_id) {
1049                self.peer.send(
1050                    connection_id,
1051                    proto::UpdateInviteInfo {
1052                        url: format!(
1053                            "{}{}",
1054                            self.app_state.config.invite_link_prefix, invite_code
1055                        ),
1056                        count: user.invite_count as u32,
1057                    },
1058                )?;
1059            }
1060        }
1061        Ok(())
1062    }
1063
1064    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1065        ServerSnapshot {
1066            connection_pool: ConnectionPoolGuard {
1067                guard: self.connection_pool.lock(),
1068                _not_send: PhantomData,
1069            },
1070            peer: &self.peer,
1071        }
1072    }
1073}
1074
1075impl Deref for ConnectionPoolGuard<'_> {
1076    type Target = ConnectionPool;
1077
1078    fn deref(&self) -> &Self::Target {
1079        &self.guard
1080    }
1081}
1082
1083impl DerefMut for ConnectionPoolGuard<'_> {
1084    fn deref_mut(&mut self) -> &mut Self::Target {
1085        &mut self.guard
1086    }
1087}
1088
impl Drop for ConnectionPoolGuard<'_> {
    /// In test builds, calls `check_invariants` whenever a guard is released.
    /// The call is reached through the guard; it is compiled out in other builds, so this is a
    /// no-op there.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1095
1096fn broadcast<F>(
1097    sender_id: Option<ConnectionId>,
1098    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1099    mut f: F,
1100) where
1101    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1102{
1103    for receiver_id in receiver_ids {
1104        if Some(receiver_id) != sender_id
1105            && let Err(error) = f(receiver_id)
1106        {
1107            tracing::error!("failed to send to {:?} {}", receiver_id, error);
1108        }
1109    }
1110}
1111
1112pub struct ProtocolVersion(u32);
1113
1114impl Header for ProtocolVersion {
1115    fn name() -> &'static HeaderName {
1116        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1117        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1118    }
1119
1120    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1121    where
1122        Self: Sized,
1123        I: Iterator<Item = &'i axum::http::HeaderValue>,
1124    {
1125        let version = values
1126            .next()
1127            .ok_or_else(axum::headers::Error::invalid)?
1128            .to_str()
1129            .map_err(|_| axum::headers::Error::invalid())?
1130            .parse()
1131            .map_err(|_| axum::headers::Error::invalid())?;
1132        Ok(Self(version))
1133    }
1134
1135    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1136        values.extend([self.0.to_string().parse().unwrap()]);
1137    }
1138}
1139
1140pub struct AppVersionHeader(SemanticVersion);
1141impl Header for AppVersionHeader {
1142    fn name() -> &'static HeaderName {
1143        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1144        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1145    }
1146
1147    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1148    where
1149        Self: Sized,
1150        I: Iterator<Item = &'i axum::http::HeaderValue>,
1151    {
1152        let version = values
1153            .next()
1154            .ok_or_else(axum::headers::Error::invalid)?
1155            .to_str()
1156            .map_err(|_| axum::headers::Error::invalid())?
1157            .parse()
1158            .map_err(|_| axum::headers::Error::invalid())?;
1159        Ok(Self(version))
1160    }
1161
1162    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1163        values.extend([self.0.to_string().parse().unwrap()]);
1164    }
1165}
1166
1167#[derive(Debug)]
1168pub struct ReleaseChannelHeader(String);
1169
1170impl Header for ReleaseChannelHeader {
1171    fn name() -> &'static HeaderName {
1172        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1173        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1174    }
1175
1176    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1177    where
1178        Self: Sized,
1179        I: Iterator<Item = &'i axum::http::HeaderValue>,
1180    {
1181        Ok(Self(
1182            values
1183                .next()
1184                .ok_or_else(axum::headers::Error::invalid)?
1185                .to_str()
1186                .map_err(|_| axum::headers::Error::invalid())?
1187                .to_owned(),
1188        ))
1189    }
1190
1191    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1192        values.extend([self.0.parse().unwrap()]);
1193    }
1194}
1195
1196pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1197    Router::new()
1198        .route("/rpc", get(handle_websocket_request))
1199        .layer(
1200            ServiceBuilder::new()
1201                .layer(Extension(server.app_state.clone()))
1202                .layer(middleware::from_fn(auth::validate_header)),
1203        )
1204        .route("/metrics", get(handle_metrics))
1205        .layer(Extension(server))
1206}
1207
1208pub async fn handle_websocket_request(
1209    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1210    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1211    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1212    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1213    Extension(server): Extension<Arc<Server>>,
1214    Extension(principal): Extension<Principal>,
1215    user_agent: Option<TypedHeader<UserAgent>>,
1216    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1217    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1218    ws: WebSocketUpgrade,
1219) -> axum::response::Response {
1220    if protocol_version != rpc::PROTOCOL_VERSION {
1221        return (
1222            StatusCode::UPGRADE_REQUIRED,
1223            "client must be upgraded".to_string(),
1224        )
1225            .into_response();
1226    }
1227
1228    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1229        return (
1230            StatusCode::UPGRADE_REQUIRED,
1231            "no version header found".to_string(),
1232        )
1233            .into_response();
1234    };
1235
1236    let release_channel = release_channel_header.map(|header| header.0.0);
1237
1238    if !version.can_collaborate() {
1239        return (
1240            StatusCode::UPGRADE_REQUIRED,
1241            "client must be upgraded".to_string(),
1242        )
1243            .into_response();
1244    }
1245
1246    let socket_address = socket_address.to_string();
1247
1248    // Acquire connection guard before WebSocket upgrade
1249    let connection_guard = match ConnectionGuard::try_acquire() {
1250        Ok(guard) => guard,
1251        Err(()) => {
1252            return (
1253                StatusCode::SERVICE_UNAVAILABLE,
1254                "Too many concurrent connections",
1255            )
1256                .into_response();
1257        }
1258    };
1259
1260    ws.on_upgrade(move |socket| {
1261        let socket = socket
1262            .map_ok(to_tungstenite_message)
1263            .err_into()
1264            .with(|message| async move { to_axum_message(message) });
1265        let connection = Connection::new(Box::pin(socket));
1266        async move {
1267            server
1268                .handle_connection(
1269                    connection,
1270                    socket_address,
1271                    principal,
1272                    version,
1273                    release_channel,
1274                    user_agent.map(|header| header.to_string()),
1275                    country_code_header.map(|header| header.to_string()),
1276                    system_id_header.map(|header| header.to_string()),
1277                    None,
1278                    Executor::Production,
1279                    Some(connection_guard),
1280                )
1281                .await;
1282        }
1283    })
1284}
1285
1286pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1287    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1288    let connections_metric = CONNECTIONS_METRIC
1289        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1290
1291    let connections = server
1292        .connection_pool
1293        .lock()
1294        .connections()
1295        .filter(|connection| !connection.admin)
1296        .count();
1297    connections_metric.set(connections as _);
1298
1299    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1300    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1301        register_int_gauge!(
1302            "shared_projects",
1303            "number of open projects with one or more guests"
1304        )
1305        .unwrap()
1306    });
1307
1308    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1309    shared_projects_metric.set(shared_projects as _);
1310
1311    let encoder = prometheus::TextEncoder::new();
1312    let metric_families = prometheus::gather();
1313    let encoded_metrics = encoder
1314        .encode_to_string(&metric_families)
1315        .map_err(|err| anyhow!("{err}"))?;
1316    Ok(encoded_metrics)
1317}
1318
1319#[instrument(err, skip(executor))]
1320async fn connection_lost(
1321    session: Session,
1322    mut teardown: watch::Receiver<bool>,
1323    executor: Executor,
1324) -> Result<()> {
1325    session.peer.disconnect(session.connection_id);
1326    session
1327        .connection_pool()
1328        .await
1329        .remove_connection(session.connection_id)?;
1330
1331    session
1332        .db()
1333        .await
1334        .connection_lost(session.connection_id)
1335        .await
1336        .trace_err();
1337
1338    futures::select_biased! {
1339        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1340
1341            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1342            leave_room_for_session(&session, session.connection_id).await.trace_err();
1343            leave_channel_buffers_for_session(&session)
1344                .await
1345                .trace_err();
1346
1347            if !session
1348                .connection_pool()
1349                .await
1350                .is_user_online(session.user_id())
1351            {
1352                let db = session.db().await;
1353                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1354                    room_updated(&room, &session.peer);
1355                }
1356            }
1357
1358            update_user_contacts(session.user_id(), &session).await?;
1359        },
1360        _ = teardown.changed().fuse() => {}
1361    }
1362
1363    Ok(())
1364}
1365
1366/// Acknowledges a ping from a client, used to keep the connection alive.
1367async fn ping(
1368    _: proto::Ping,
1369    response: Response<proto::Ping>,
1370    _session: MessageContext,
1371) -> Result<()> {
1372    response.send(proto::Ack {})?;
1373    Ok(())
1374}
1375
1376/// Creates a new room for calling (outside of channels)
1377async fn create_room(
1378    _request: proto::CreateRoom,
1379    response: Response<proto::CreateRoom>,
1380    session: MessageContext,
1381) -> Result<()> {
1382    let livekit_room = nanoid::nanoid!(30);
1383
1384    let live_kit_connection_info = util::maybe!(async {
1385        let live_kit = session.app_state.livekit_client.as_ref();
1386        let live_kit = live_kit?;
1387        let user_id = session.user_id().to_string();
1388
1389        let token = live_kit
1390            .room_token(&livekit_room, &user_id.to_string())
1391            .trace_err()?;
1392
1393        Some(proto::LiveKitConnectionInfo {
1394            server_url: live_kit.url().into(),
1395            token,
1396            can_publish: true,
1397        })
1398    })
1399    .await;
1400
1401    let room = session
1402        .db()
1403        .await
1404        .create_room(session.user_id(), session.connection_id, &livekit_room)
1405        .await?;
1406
1407    response.send(proto::CreateRoomResponse {
1408        room: Some(room.clone()),
1409        live_kit_connection_info,
1410    })?;
1411
1412    update_user_contacts(session.user_id(), &session).await?;
1413    Ok(())
1414}
1415
1416/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1417async fn join_room(
1418    request: proto::JoinRoom,
1419    response: Response<proto::JoinRoom>,
1420    session: MessageContext,
1421) -> Result<()> {
1422    let room_id = RoomId::from_proto(request.id);
1423
1424    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1425
1426    if let Some(channel_id) = channel_id {
1427        return join_channel_internal(channel_id, Box::new(response), session).await;
1428    }
1429
1430    let joined_room = {
1431        let room = session
1432            .db()
1433            .await
1434            .join_room(room_id, session.user_id(), session.connection_id)
1435            .await?;
1436        room_updated(&room.room, &session.peer);
1437        room.into_inner()
1438    };
1439
1440    for connection_id in session
1441        .connection_pool()
1442        .await
1443        .user_connection_ids(session.user_id())
1444    {
1445        session
1446            .peer
1447            .send(
1448                connection_id,
1449                proto::CallCanceled {
1450                    room_id: room_id.to_proto(),
1451                },
1452            )
1453            .trace_err();
1454    }
1455
1456    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1457    {
1458        live_kit
1459            .room_token(
1460                &joined_room.room.livekit_room,
1461                &session.user_id().to_string(),
1462            )
1463            .trace_err()
1464            .map(|token| proto::LiveKitConnectionInfo {
1465                server_url: live_kit.url().into(),
1466                token,
1467                can_publish: true,
1468            })
1469    } else {
1470        None
1471    };
1472
1473    response.send(proto::JoinRoomResponse {
1474        room: Some(joined_room.room),
1475        channel_id: None,
1476        live_kit_connection_info,
1477    })?;
1478
1479    update_user_contacts(session.user_id(), &session).await?;
1480    Ok(())
1481}
1482
1483/// Rejoin room is used to reconnect to a room after connection errors.
1484async fn rejoin_room(
1485    request: proto::RejoinRoom,
1486    response: Response<proto::RejoinRoom>,
1487    session: MessageContext,
1488) -> Result<()> {
1489    let room;
1490    let channel;
1491    {
1492        let mut rejoined_room = session
1493            .db()
1494            .await
1495            .rejoin_room(request, session.user_id(), session.connection_id)
1496            .await?;
1497
1498        response.send(proto::RejoinRoomResponse {
1499            room: Some(rejoined_room.room.clone()),
1500            reshared_projects: rejoined_room
1501                .reshared_projects
1502                .iter()
1503                .map(|project| proto::ResharedProject {
1504                    id: project.id.to_proto(),
1505                    collaborators: project
1506                        .collaborators
1507                        .iter()
1508                        .map(|collaborator| collaborator.to_proto())
1509                        .collect(),
1510                })
1511                .collect(),
1512            rejoined_projects: rejoined_room
1513                .rejoined_projects
1514                .iter()
1515                .map(|rejoined_project| rejoined_project.to_proto())
1516                .collect(),
1517        })?;
1518        room_updated(&rejoined_room.room, &session.peer);
1519
1520        for project in &rejoined_room.reshared_projects {
1521            for collaborator in &project.collaborators {
1522                session
1523                    .peer
1524                    .send(
1525                        collaborator.connection_id,
1526                        proto::UpdateProjectCollaborator {
1527                            project_id: project.id.to_proto(),
1528                            old_peer_id: Some(project.old_connection_id.into()),
1529                            new_peer_id: Some(session.connection_id.into()),
1530                        },
1531                    )
1532                    .trace_err();
1533            }
1534
1535            broadcast(
1536                Some(session.connection_id),
1537                project
1538                    .collaborators
1539                    .iter()
1540                    .map(|collaborator| collaborator.connection_id),
1541                |connection_id| {
1542                    session.peer.forward_send(
1543                        session.connection_id,
1544                        connection_id,
1545                        proto::UpdateProject {
1546                            project_id: project.id.to_proto(),
1547                            worktrees: project.worktrees.clone(),
1548                        },
1549                    )
1550                },
1551            );
1552        }
1553
1554        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1555
1556        let rejoined_room = rejoined_room.into_inner();
1557
1558        room = rejoined_room.room;
1559        channel = rejoined_room.channel;
1560    }
1561
1562    if let Some(channel) = channel {
1563        channel_updated(
1564            &channel,
1565            &room,
1566            &session.peer,
1567            &*session.connection_pool().await,
1568        );
1569    }
1570
1571    update_user_contacts(session.user_id(), &session).await?;
1572    Ok(())
1573}
1574
1575fn notify_rejoined_projects(
1576    rejoined_projects: &mut Vec<RejoinedProject>,
1577    session: &Session,
1578) -> Result<()> {
1579    for project in rejoined_projects.iter() {
1580        for collaborator in &project.collaborators {
1581            session
1582                .peer
1583                .send(
1584                    collaborator.connection_id,
1585                    proto::UpdateProjectCollaborator {
1586                        project_id: project.id.to_proto(),
1587                        old_peer_id: Some(project.old_connection_id.into()),
1588                        new_peer_id: Some(session.connection_id.into()),
1589                    },
1590                )
1591                .trace_err();
1592        }
1593    }
1594
1595    for project in rejoined_projects {
1596        for worktree in mem::take(&mut project.worktrees) {
1597            // Stream this worktree's entries.
1598            let message = proto::UpdateWorktree {
1599                project_id: project.id.to_proto(),
1600                worktree_id: worktree.id,
1601                abs_path: worktree.abs_path.clone(),
1602                root_name: worktree.root_name,
1603                updated_entries: worktree.updated_entries,
1604                removed_entries: worktree.removed_entries,
1605                scan_id: worktree.scan_id,
1606                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1607                updated_repositories: worktree.updated_repositories,
1608                removed_repositories: worktree.removed_repositories,
1609            };
1610            for update in proto::split_worktree_update(message) {
1611                session.peer.send(session.connection_id, update)?;
1612            }
1613
1614            // Stream this worktree's diagnostics.
1615            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1616            if let Some(summary) = worktree_diagnostics.next() {
1617                let message = proto::UpdateDiagnosticSummary {
1618                    project_id: project.id.to_proto(),
1619                    worktree_id: worktree.id,
1620                    summary: Some(summary),
1621                    more_summaries: worktree_diagnostics.collect(),
1622                };
1623                session.peer.send(session.connection_id, message)?;
1624            }
1625
1626            for settings_file in worktree.settings_files {
1627                session.peer.send(
1628                    session.connection_id,
1629                    proto::UpdateWorktreeSettings {
1630                        project_id: project.id.to_proto(),
1631                        worktree_id: worktree.id,
1632                        path: settings_file.path,
1633                        content: Some(settings_file.content),
1634                        kind: Some(settings_file.kind.to_proto().into()),
1635                    },
1636                )?;
1637            }
1638        }
1639
1640        for repository in mem::take(&mut project.updated_repositories) {
1641            for update in split_repository_update(repository) {
1642                session.peer.send(session.connection_id, update)?;
1643            }
1644        }
1645
1646        for id in mem::take(&mut project.removed_repositories) {
1647            session.peer.send(
1648                session.connection_id,
1649                proto::RemoveRepository {
1650                    project_id: project.id.to_proto(),
1651                    id,
1652                },
1653            )?;
1654        }
1655    }
1656
1657    Ok(())
1658}
1659
1660/// leave room disconnects from the room.
1661async fn leave_room(
1662    _: proto::LeaveRoom,
1663    response: Response<proto::LeaveRoom>,
1664    session: MessageContext,
1665) -> Result<()> {
1666    leave_room_for_session(&session, session.connection_id).await?;
1667    response.send(proto::Ack {})?;
1668    Ok(())
1669}
1670
1671/// Updates the permissions of someone else in the room.
1672async fn set_room_participant_role(
1673    request: proto::SetRoomParticipantRole,
1674    response: Response<proto::SetRoomParticipantRole>,
1675    session: MessageContext,
1676) -> Result<()> {
1677    let user_id = UserId::from_proto(request.user_id);
1678    let role = ChannelRole::from(request.role());
1679
1680    let (livekit_room, can_publish) = {
1681        let room = session
1682            .db()
1683            .await
1684            .set_room_participant_role(
1685                session.user_id(),
1686                RoomId::from_proto(request.room_id),
1687                user_id,
1688                role,
1689            )
1690            .await?;
1691
1692        let livekit_room = room.livekit_room.clone();
1693        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1694        room_updated(&room, &session.peer);
1695        (livekit_room, can_publish)
1696    };
1697
1698    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1699        live_kit
1700            .update_participant(
1701                livekit_room.clone(),
1702                request.user_id.to_string(),
1703                livekit_api::proto::ParticipantPermission {
1704                    can_subscribe: true,
1705                    can_publish,
1706                    can_publish_data: can_publish,
1707                    hidden: false,
1708                    recorder: false,
1709                },
1710            )
1711            .await
1712            .trace_err();
1713    }
1714
1715    response.send(proto::Ack {})?;
1716    Ok(())
1717}
1718
1719/// Call someone else into the current room
1720async fn call(
1721    request: proto::Call,
1722    response: Response<proto::Call>,
1723    session: MessageContext,
1724) -> Result<()> {
1725    let room_id = RoomId::from_proto(request.room_id);
1726    let calling_user_id = session.user_id();
1727    let calling_connection_id = session.connection_id;
1728    let called_user_id = UserId::from_proto(request.called_user_id);
1729    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1730    if !session
1731        .db()
1732        .await
1733        .has_contact(calling_user_id, called_user_id)
1734        .await?
1735    {
1736        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1737    }
1738
1739    let incoming_call = {
1740        let (room, incoming_call) = &mut *session
1741            .db()
1742            .await
1743            .call(
1744                room_id,
1745                calling_user_id,
1746                calling_connection_id,
1747                called_user_id,
1748                initial_project_id,
1749            )
1750            .await?;
1751        room_updated(room, &session.peer);
1752        mem::take(incoming_call)
1753    };
1754    update_user_contacts(called_user_id, &session).await?;
1755
1756    let mut calls = session
1757        .connection_pool()
1758        .await
1759        .user_connection_ids(called_user_id)
1760        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1761        .collect::<FuturesUnordered<_>>();
1762
1763    while let Some(call_response) = calls.next().await {
1764        match call_response.as_ref() {
1765            Ok(_) => {
1766                response.send(proto::Ack {})?;
1767                return Ok(());
1768            }
1769            Err(_) => {
1770                call_response.trace_err();
1771            }
1772        }
1773    }
1774
1775    {
1776        let room = session
1777            .db()
1778            .await
1779            .call_failed(room_id, called_user_id)
1780            .await?;
1781        room_updated(&room, &session.peer);
1782    }
1783    update_user_contacts(called_user_id, &session).await?;
1784
1785    Err(anyhow!("failed to ring user"))?
1786}
1787
1788/// Cancel an outgoing call.
1789async fn cancel_call(
1790    request: proto::CancelCall,
1791    response: Response<proto::CancelCall>,
1792    session: MessageContext,
1793) -> Result<()> {
1794    let called_user_id = UserId::from_proto(request.called_user_id);
1795    let room_id = RoomId::from_proto(request.room_id);
1796    {
1797        let room = session
1798            .db()
1799            .await
1800            .cancel_call(room_id, session.connection_id, called_user_id)
1801            .await?;
1802        room_updated(&room, &session.peer);
1803    }
1804
1805    for connection_id in session
1806        .connection_pool()
1807        .await
1808        .user_connection_ids(called_user_id)
1809    {
1810        session
1811            .peer
1812            .send(
1813                connection_id,
1814                proto::CallCanceled {
1815                    room_id: room_id.to_proto(),
1816                },
1817            )
1818            .trace_err();
1819    }
1820    response.send(proto::Ack {})?;
1821
1822    update_user_contacts(called_user_id, &session).await?;
1823    Ok(())
1824}
1825
1826/// Decline an incoming call.
1827async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1828    let room_id = RoomId::from_proto(message.room_id);
1829    {
1830        let room = session
1831            .db()
1832            .await
1833            .decline_call(Some(room_id), session.user_id())
1834            .await?
1835            .context("declining call")?;
1836        room_updated(&room, &session.peer);
1837    }
1838
1839    for connection_id in session
1840        .connection_pool()
1841        .await
1842        .user_connection_ids(session.user_id())
1843    {
1844        session
1845            .peer
1846            .send(
1847                connection_id,
1848                proto::CallCanceled {
1849                    room_id: room_id.to_proto(),
1850                },
1851            )
1852            .trace_err();
1853    }
1854    update_user_contacts(session.user_id(), &session).await?;
1855    Ok(())
1856}
1857
1858/// Updates other participants in the room with your current location.
1859async fn update_participant_location(
1860    request: proto::UpdateParticipantLocation,
1861    response: Response<proto::UpdateParticipantLocation>,
1862    session: MessageContext,
1863) -> Result<()> {
1864    let room_id = RoomId::from_proto(request.room_id);
1865    let location = request.location.context("invalid location")?;
1866
1867    let db = session.db().await;
1868    let room = db
1869        .update_room_participant_location(room_id, session.connection_id, location)
1870        .await?;
1871
1872    room_updated(&room, &session.peer);
1873    response.send(proto::Ack {})?;
1874    Ok(())
1875}
1876
1877/// Share a project into the room.
1878async fn share_project(
1879    request: proto::ShareProject,
1880    response: Response<proto::ShareProject>,
1881    session: MessageContext,
1882) -> Result<()> {
1883    let (project_id, room) = &*session
1884        .db()
1885        .await
1886        .share_project(
1887            RoomId::from_proto(request.room_id),
1888            session.connection_id,
1889            &request.worktrees,
1890            request.is_ssh_project,
1891        )
1892        .await?;
1893    response.send(proto::ShareProjectResponse {
1894        project_id: project_id.to_proto(),
1895    })?;
1896    room_updated(room, &session.peer);
1897
1898    Ok(())
1899}
1900
1901/// Unshare a project from the room.
1902async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1903    let project_id = ProjectId::from_proto(message.project_id);
1904    unshare_project_internal(project_id, session.connection_id, &session).await
1905}
1906
1907async fn unshare_project_internal(
1908    project_id: ProjectId,
1909    connection_id: ConnectionId,
1910    session: &Session,
1911) -> Result<()> {
1912    let delete = {
1913        let room_guard = session
1914            .db()
1915            .await
1916            .unshare_project(project_id, connection_id)
1917            .await?;
1918
1919        let (delete, room, guest_connection_ids) = &*room_guard;
1920
1921        let message = proto::UnshareProject {
1922            project_id: project_id.to_proto(),
1923        };
1924
1925        broadcast(
1926            Some(connection_id),
1927            guest_connection_ids.iter().copied(),
1928            |conn_id| session.peer.send(conn_id, message.clone()),
1929        );
1930        if let Some(room) = room {
1931            room_updated(room, &session.peer);
1932        }
1933
1934        *delete
1935    };
1936
1937    if delete {
1938        let db = session.db().await;
1939        db.delete_project(project_id).await?;
1940    }
1941
1942    Ok(())
1943}
1944
1945/// Join someone elses shared project.
1946async fn join_project(
1947    request: proto::JoinProject,
1948    response: Response<proto::JoinProject>,
1949    session: MessageContext,
1950) -> Result<()> {
1951    let project_id = ProjectId::from_proto(request.project_id);
1952
1953    tracing::info!(%project_id, "join project");
1954
1955    let db = session.db().await;
1956    let (project, replica_id) = &mut *db
1957        .join_project(
1958            project_id,
1959            session.connection_id,
1960            session.user_id(),
1961            request.committer_name.clone(),
1962            request.committer_email.clone(),
1963        )
1964        .await?;
1965    drop(db);
1966    tracing::info!(%project_id, "join remote project");
1967    let collaborators = project
1968        .collaborators
1969        .iter()
1970        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1971        .map(|collaborator| collaborator.to_proto())
1972        .collect::<Vec<_>>();
1973    let project_id = project.id;
1974    let guest_user_id = session.user_id();
1975
1976    let worktrees = project
1977        .worktrees
1978        .iter()
1979        .map(|(id, worktree)| proto::WorktreeMetadata {
1980            id: *id,
1981            root_name: worktree.root_name.clone(),
1982            visible: worktree.visible,
1983            abs_path: worktree.abs_path.clone(),
1984        })
1985        .collect::<Vec<_>>();
1986
1987    let add_project_collaborator = proto::AddProjectCollaborator {
1988        project_id: project_id.to_proto(),
1989        collaborator: Some(proto::Collaborator {
1990            peer_id: Some(session.connection_id.into()),
1991            replica_id: replica_id.0 as u32,
1992            user_id: guest_user_id.to_proto(),
1993            is_host: false,
1994            committer_name: request.committer_name.clone(),
1995            committer_email: request.committer_email.clone(),
1996        }),
1997    };
1998
1999    for collaborator in &collaborators {
2000        session
2001            .peer
2002            .send(
2003                collaborator.peer_id.unwrap().into(),
2004                add_project_collaborator.clone(),
2005            )
2006            .trace_err();
2007    }
2008
2009    // First, we send the metadata associated with each worktree.
2010    let (language_servers, language_server_capabilities) = project
2011        .language_servers
2012        .clone()
2013        .into_iter()
2014        .map(|server| (server.server, server.capabilities))
2015        .unzip();
2016    response.send(proto::JoinProjectResponse {
2017        project_id: project.id.0 as u64,
2018        worktrees: worktrees.clone(),
2019        replica_id: replica_id.0 as u32,
2020        collaborators: collaborators.clone(),
2021        language_servers,
2022        language_server_capabilities,
2023        role: project.role.into(),
2024    })?;
2025
2026    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2027        // Stream this worktree's entries.
2028        let message = proto::UpdateWorktree {
2029            project_id: project_id.to_proto(),
2030            worktree_id,
2031            abs_path: worktree.abs_path.clone(),
2032            root_name: worktree.root_name,
2033            updated_entries: worktree.entries,
2034            removed_entries: Default::default(),
2035            scan_id: worktree.scan_id,
2036            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2037            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2038            removed_repositories: Default::default(),
2039        };
2040        for update in proto::split_worktree_update(message) {
2041            session.peer.send(session.connection_id, update.clone())?;
2042        }
2043
2044        // Stream this worktree's diagnostics.
2045        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
2046        if let Some(summary) = worktree_diagnostics.next() {
2047            let message = proto::UpdateDiagnosticSummary {
2048                project_id: project.id.to_proto(),
2049                worktree_id: worktree.id,
2050                summary: Some(summary),
2051                more_summaries: worktree_diagnostics.collect(),
2052            };
2053            session.peer.send(session.connection_id, message)?;
2054        }
2055
2056        for settings_file in worktree.settings_files {
2057            session.peer.send(
2058                session.connection_id,
2059                proto::UpdateWorktreeSettings {
2060                    project_id: project_id.to_proto(),
2061                    worktree_id: worktree.id,
2062                    path: settings_file.path,
2063                    content: Some(settings_file.content),
2064                    kind: Some(settings_file.kind.to_proto() as i32),
2065                },
2066            )?;
2067        }
2068    }
2069
2070    for repository in mem::take(&mut project.repositories) {
2071        for update in split_repository_update(repository) {
2072            session.peer.send(session.connection_id, update)?;
2073        }
2074    }
2075
2076    for language_server in &project.language_servers {
2077        session.peer.send(
2078            session.connection_id,
2079            proto::UpdateLanguageServer {
2080                project_id: project_id.to_proto(),
2081                server_name: Some(language_server.server.name.clone()),
2082                language_server_id: language_server.server.id,
2083                variant: Some(
2084                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2085                        proto::LspDiskBasedDiagnosticsUpdated {},
2086                    ),
2087                ),
2088            },
2089        )?;
2090    }
2091
2092    Ok(())
2093}
2094
2095/// Leave someone elses shared project.
2096async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2097    let sender_id = session.connection_id;
2098    let project_id = ProjectId::from_proto(request.project_id);
2099    let db = session.db().await;
2100
2101    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2102    tracing::info!(
2103        %project_id,
2104        "leave project"
2105    );
2106
2107    project_left(project, &session);
2108    if let Some(room) = room {
2109        room_updated(room, &session.peer);
2110    }
2111
2112    Ok(())
2113}
2114
2115/// Updates other participants with changes to the project
2116async fn update_project(
2117    request: proto::UpdateProject,
2118    response: Response<proto::UpdateProject>,
2119    session: MessageContext,
2120) -> Result<()> {
2121    let project_id = ProjectId::from_proto(request.project_id);
2122    let (room, guest_connection_ids) = &*session
2123        .db()
2124        .await
2125        .update_project(project_id, session.connection_id, &request.worktrees)
2126        .await?;
2127    broadcast(
2128        Some(session.connection_id),
2129        guest_connection_ids.iter().copied(),
2130        |connection_id| {
2131            session
2132                .peer
2133                .forward_send(session.connection_id, connection_id, request.clone())
2134        },
2135    );
2136    if let Some(room) = room {
2137        room_updated(room, &session.peer);
2138    }
2139    response.send(proto::Ack {})?;
2140
2141    Ok(())
2142}
2143
2144/// Updates other participants with changes to the worktree
2145async fn update_worktree(
2146    request: proto::UpdateWorktree,
2147    response: Response<proto::UpdateWorktree>,
2148    session: MessageContext,
2149) -> Result<()> {
2150    let guest_connection_ids = session
2151        .db()
2152        .await
2153        .update_worktree(&request, session.connection_id)
2154        .await?;
2155
2156    broadcast(
2157        Some(session.connection_id),
2158        guest_connection_ids.iter().copied(),
2159        |connection_id| {
2160            session
2161                .peer
2162                .forward_send(session.connection_id, connection_id, request.clone())
2163        },
2164    );
2165    response.send(proto::Ack {})?;
2166    Ok(())
2167}
2168
2169async fn update_repository(
2170    request: proto::UpdateRepository,
2171    response: Response<proto::UpdateRepository>,
2172    session: MessageContext,
2173) -> Result<()> {
2174    let guest_connection_ids = session
2175        .db()
2176        .await
2177        .update_repository(&request, session.connection_id)
2178        .await?;
2179
2180    broadcast(
2181        Some(session.connection_id),
2182        guest_connection_ids.iter().copied(),
2183        |connection_id| {
2184            session
2185                .peer
2186                .forward_send(session.connection_id, connection_id, request.clone())
2187        },
2188    );
2189    response.send(proto::Ack {})?;
2190    Ok(())
2191}
2192
2193async fn remove_repository(
2194    request: proto::RemoveRepository,
2195    response: Response<proto::RemoveRepository>,
2196    session: MessageContext,
2197) -> Result<()> {
2198    let guest_connection_ids = session
2199        .db()
2200        .await
2201        .remove_repository(&request, session.connection_id)
2202        .await?;
2203
2204    broadcast(
2205        Some(session.connection_id),
2206        guest_connection_ids.iter().copied(),
2207        |connection_id| {
2208            session
2209                .peer
2210                .forward_send(session.connection_id, connection_id, request.clone())
2211        },
2212    );
2213    response.send(proto::Ack {})?;
2214    Ok(())
2215}
2216
2217/// Updates other participants with changes to the diagnostics
2218async fn update_diagnostic_summary(
2219    message: proto::UpdateDiagnosticSummary,
2220    session: MessageContext,
2221) -> Result<()> {
2222    let guest_connection_ids = session
2223        .db()
2224        .await
2225        .update_diagnostic_summary(&message, session.connection_id)
2226        .await?;
2227
2228    broadcast(
2229        Some(session.connection_id),
2230        guest_connection_ids.iter().copied(),
2231        |connection_id| {
2232            session
2233                .peer
2234                .forward_send(session.connection_id, connection_id, message.clone())
2235        },
2236    );
2237
2238    Ok(())
2239}
2240
2241/// Updates other participants with changes to the worktree settings
2242async fn update_worktree_settings(
2243    message: proto::UpdateWorktreeSettings,
2244    session: MessageContext,
2245) -> Result<()> {
2246    let guest_connection_ids = session
2247        .db()
2248        .await
2249        .update_worktree_settings(&message, session.connection_id)
2250        .await?;
2251
2252    broadcast(
2253        Some(session.connection_id),
2254        guest_connection_ids.iter().copied(),
2255        |connection_id| {
2256            session
2257                .peer
2258                .forward_send(session.connection_id, connection_id, message.clone())
2259        },
2260    );
2261
2262    Ok(())
2263}
2264
2265/// Notify other participants that a language server has started.
2266async fn start_language_server(
2267    request: proto::StartLanguageServer,
2268    session: MessageContext,
2269) -> Result<()> {
2270    let guest_connection_ids = session
2271        .db()
2272        .await
2273        .start_language_server(&request, session.connection_id)
2274        .await?;
2275
2276    broadcast(
2277        Some(session.connection_id),
2278        guest_connection_ids.iter().copied(),
2279        |connection_id| {
2280            session
2281                .peer
2282                .forward_send(session.connection_id, connection_id, request.clone())
2283        },
2284    );
2285    Ok(())
2286}
2287
2288/// Notify other participants that a language server has changed.
2289async fn update_language_server(
2290    request: proto::UpdateLanguageServer,
2291    session: MessageContext,
2292) -> Result<()> {
2293    let project_id = ProjectId::from_proto(request.project_id);
2294    let db = session.db().await;
2295
2296    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2297        && let Some(capabilities) = update.capabilities.clone()
2298    {
2299        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2300            .await?;
2301    }
2302
2303    let project_connection_ids = db
2304        .project_connection_ids(project_id, session.connection_id, true)
2305        .await?;
2306    broadcast(
2307        Some(session.connection_id),
2308        project_connection_ids.iter().copied(),
2309        |connection_id| {
2310            session
2311                .peer
2312                .forward_send(session.connection_id, connection_id, request.clone())
2313        },
2314    );
2315    Ok(())
2316}
2317
2318/// forward a project request to the host. These requests should be read only
2319/// as guests are allowed to send them.
2320async fn forward_read_only_project_request<T>(
2321    request: T,
2322    response: Response<T>,
2323    session: MessageContext,
2324) -> Result<()>
2325where
2326    T: EntityMessage + RequestMessage,
2327{
2328    let project_id = ProjectId::from_proto(request.remote_entity_id());
2329    let host_connection_id = session
2330        .db()
2331        .await
2332        .host_for_read_only_project_request(project_id, session.connection_id)
2333        .await?;
2334    let payload = session.forward_request(host_connection_id, request).await?;
2335    response.send(payload)?;
2336    Ok(())
2337}
2338
2339/// forward a project request to the host. These requests are disallowed
2340/// for guests.
2341async fn forward_mutating_project_request<T>(
2342    request: T,
2343    response: Response<T>,
2344    session: MessageContext,
2345) -> Result<()>
2346where
2347    T: EntityMessage + RequestMessage,
2348{
2349    let project_id = ProjectId::from_proto(request.remote_entity_id());
2350
2351    let host_connection_id = session
2352        .db()
2353        .await
2354        .host_for_mutating_project_request(project_id, session.connection_id)
2355        .await?;
2356    let payload = session.forward_request(host_connection_id, request).await?;
2357    response.send(payload)?;
2358    Ok(())
2359}
2360
2361async fn multi_lsp_query(
2362    request: MultiLspQuery,
2363    response: Response<MultiLspQuery>,
2364    session: MessageContext,
2365) -> Result<()> {
2366    tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2367    tracing::info!("multi_lsp_query message received");
2368    forward_mutating_project_request(request, response, session).await
2369}
2370
2371/// Notify other participants that a new buffer has been created
2372async fn create_buffer_for_peer(
2373    request: proto::CreateBufferForPeer,
2374    session: MessageContext,
2375) -> Result<()> {
2376    session
2377        .db()
2378        .await
2379        .check_user_is_project_host(
2380            ProjectId::from_proto(request.project_id),
2381            session.connection_id,
2382        )
2383        .await?;
2384    let peer_id = request.peer_id.context("invalid peer id")?;
2385    session
2386        .peer
2387        .forward_send(session.connection_id, peer_id.into(), request)?;
2388    Ok(())
2389}
2390
2391/// Notify other participants that a buffer has been updated. This is
2392/// allowed for guests as long as the update is limited to selections.
2393async fn update_buffer(
2394    request: proto::UpdateBuffer,
2395    response: Response<proto::UpdateBuffer>,
2396    session: MessageContext,
2397) -> Result<()> {
2398    let project_id = ProjectId::from_proto(request.project_id);
2399    let mut capability = Capability::ReadOnly;
2400
2401    for op in request.operations.iter() {
2402        match op.variant {
2403            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2404            Some(_) => capability = Capability::ReadWrite,
2405        }
2406    }
2407
2408    let host = {
2409        let guard = session
2410            .db()
2411            .await
2412            .connections_for_buffer_update(project_id, session.connection_id, capability)
2413            .await?;
2414
2415        let (host, guests) = &*guard;
2416
2417        broadcast(
2418            Some(session.connection_id),
2419            guests.clone(),
2420            |connection_id| {
2421                session
2422                    .peer
2423                    .forward_send(session.connection_id, connection_id, request.clone())
2424            },
2425        );
2426
2427        *host
2428    };
2429
2430    if host != session.connection_id {
2431        session.forward_request(host, request.clone()).await?;
2432    }
2433
2434    response.send(proto::Ack {})?;
2435    Ok(())
2436}
2437
2438async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2439    let project_id = ProjectId::from_proto(message.project_id);
2440
2441    let operation = message.operation.as_ref().context("invalid operation")?;
2442    let capability = match operation.variant.as_ref() {
2443        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2444            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2445                match buffer_op.variant {
2446                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2447                        Capability::ReadOnly
2448                    }
2449                    _ => Capability::ReadWrite,
2450                }
2451            } else {
2452                Capability::ReadWrite
2453            }
2454        }
2455        Some(_) => Capability::ReadWrite,
2456        None => Capability::ReadOnly,
2457    };
2458
2459    let guard = session
2460        .db()
2461        .await
2462        .connections_for_buffer_update(project_id, session.connection_id, capability)
2463        .await?;
2464
2465    let (host, guests) = &*guard;
2466
2467    broadcast(
2468        Some(session.connection_id),
2469        guests.iter().chain([host]).copied(),
2470        |connection_id| {
2471            session
2472                .peer
2473                .forward_send(session.connection_id, connection_id, message.clone())
2474        },
2475    );
2476
2477    Ok(())
2478}
2479
2480/// Notify other participants that a project has been updated.
2481async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2482    request: T,
2483    session: MessageContext,
2484) -> Result<()> {
2485    let project_id = ProjectId::from_proto(request.remote_entity_id());
2486    let project_connection_ids = session
2487        .db()
2488        .await
2489        .project_connection_ids(project_id, session.connection_id, false)
2490        .await?;
2491
2492    broadcast(
2493        Some(session.connection_id),
2494        project_connection_ids.iter().copied(),
2495        |connection_id| {
2496            session
2497                .peer
2498                .forward_send(session.connection_id, connection_id, request.clone())
2499        },
2500    );
2501    Ok(())
2502}
2503
2504/// Start following another user in a call.
2505async fn follow(
2506    request: proto::Follow,
2507    response: Response<proto::Follow>,
2508    session: MessageContext,
2509) -> Result<()> {
2510    let room_id = RoomId::from_proto(request.room_id);
2511    let project_id = request.project_id.map(ProjectId::from_proto);
2512    let leader_id = request.leader_id.context("invalid leader id")?.into();
2513    let follower_id = session.connection_id;
2514
2515    session
2516        .db()
2517        .await
2518        .check_room_participants(room_id, leader_id, session.connection_id)
2519        .await?;
2520
2521    let response_payload = session.forward_request(leader_id, request).await?;
2522    response.send(response_payload)?;
2523
2524    if let Some(project_id) = project_id {
2525        let room = session
2526            .db()
2527            .await
2528            .follow(room_id, project_id, leader_id, follower_id)
2529            .await?;
2530        room_updated(&room, &session.peer);
2531    }
2532
2533    Ok(())
2534}
2535
2536/// Stop following another user in a call.
2537async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2538    let room_id = RoomId::from_proto(request.room_id);
2539    let project_id = request.project_id.map(ProjectId::from_proto);
2540    let leader_id = request.leader_id.context("invalid leader id")?.into();
2541    let follower_id = session.connection_id;
2542
2543    session
2544        .db()
2545        .await
2546        .check_room_participants(room_id, leader_id, session.connection_id)
2547        .await?;
2548
2549    session
2550        .peer
2551        .forward_send(session.connection_id, leader_id, request)?;
2552
2553    if let Some(project_id) = project_id {
2554        let room = session
2555            .db()
2556            .await
2557            .unfollow(room_id, project_id, leader_id, follower_id)
2558            .await?;
2559        room_updated(&room, &session.peer);
2560    }
2561
2562    Ok(())
2563}
2564
2565/// Notify everyone following you of your current location.
2566async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2567    let room_id = RoomId::from_proto(request.room_id);
2568    let database = session.db.lock().await;
2569
2570    let connection_ids = if let Some(project_id) = request.project_id {
2571        let project_id = ProjectId::from_proto(project_id);
2572        database
2573            .project_connection_ids(project_id, session.connection_id, true)
2574            .await?
2575    } else {
2576        database
2577            .room_connection_ids(room_id, session.connection_id)
2578            .await?
2579    };
2580
2581    // For now, don't send view update messages back to that view's current leader.
2582    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2583        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2584        _ => None,
2585    });
2586
2587    for connection_id in connection_ids.iter().cloned() {
2588        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2589            session
2590                .peer
2591                .forward_send(session.connection_id, connection_id, request.clone())?;
2592        }
2593    }
2594    Ok(())
2595}
2596
2597/// Get public data about users.
2598async fn get_users(
2599    request: proto::GetUsers,
2600    response: Response<proto::GetUsers>,
2601    session: MessageContext,
2602) -> Result<()> {
2603    let user_ids = request
2604        .user_ids
2605        .into_iter()
2606        .map(UserId::from_proto)
2607        .collect();
2608    let users = session
2609        .db()
2610        .await
2611        .get_users_by_ids(user_ids)
2612        .await?
2613        .into_iter()
2614        .map(|user| proto::User {
2615            id: user.id.to_proto(),
2616            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2617            github_login: user.github_login,
2618            name: user.name,
2619        })
2620        .collect();
2621    response.send(proto::UsersResponse { users })?;
2622    Ok(())
2623}
2624
2625/// Search for users (to invite) buy Github login
2626async fn fuzzy_search_users(
2627    request: proto::FuzzySearchUsers,
2628    response: Response<proto::FuzzySearchUsers>,
2629    session: MessageContext,
2630) -> Result<()> {
2631    let query = request.query;
2632    let users = match query.len() {
2633        0 => vec![],
2634        1 | 2 => session
2635            .db()
2636            .await
2637            .get_user_by_github_login(&query)
2638            .await?
2639            .into_iter()
2640            .collect(),
2641        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2642    };
2643    let users = users
2644        .into_iter()
2645        .filter(|user| user.id != session.user_id())
2646        .map(|user| proto::User {
2647            id: user.id.to_proto(),
2648            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2649            github_login: user.github_login,
2650            name: user.name,
2651        })
2652        .collect();
2653    response.send(proto::UsersResponse { users })?;
2654    Ok(())
2655}
2656
2657/// Send a contact request to another user.
2658async fn request_contact(
2659    request: proto::RequestContact,
2660    response: Response<proto::RequestContact>,
2661    session: MessageContext,
2662) -> Result<()> {
2663    let requester_id = session.user_id();
2664    let responder_id = UserId::from_proto(request.responder_id);
2665    if requester_id == responder_id {
2666        return Err(anyhow!("cannot add yourself as a contact"))?;
2667    }
2668
2669    let notifications = session
2670        .db()
2671        .await
2672        .send_contact_request(requester_id, responder_id)
2673        .await?;
2674
2675    // Update outgoing contact requests of requester
2676    let mut update = proto::UpdateContacts::default();
2677    update.outgoing_requests.push(responder_id.to_proto());
2678    for connection_id in session
2679        .connection_pool()
2680        .await
2681        .user_connection_ids(requester_id)
2682    {
2683        session.peer.send(connection_id, update.clone())?;
2684    }
2685
2686    // Update incoming contact requests of responder
2687    let mut update = proto::UpdateContacts::default();
2688    update
2689        .incoming_requests
2690        .push(proto::IncomingContactRequest {
2691            requester_id: requester_id.to_proto(),
2692        });
2693    let connection_pool = session.connection_pool().await;
2694    for connection_id in connection_pool.user_connection_ids(responder_id) {
2695        session.peer.send(connection_id, update.clone())?;
2696    }
2697
2698    send_notifications(&connection_pool, &session.peer, notifications);
2699
2700    response.send(proto::Ack {})?;
2701    Ok(())
2702}
2703
2704/// Accept or decline a contact request
2705async fn respond_to_contact_request(
2706    request: proto::RespondToContactRequest,
2707    response: Response<proto::RespondToContactRequest>,
2708    session: MessageContext,
2709) -> Result<()> {
2710    let responder_id = session.user_id();
2711    let requester_id = UserId::from_proto(request.requester_id);
2712    let db = session.db().await;
2713    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2714        db.dismiss_contact_notification(responder_id, requester_id)
2715            .await?;
2716    } else {
2717        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2718
2719        let notifications = db
2720            .respond_to_contact_request(responder_id, requester_id, accept)
2721            .await?;
2722        let requester_busy = db.is_user_busy(requester_id).await?;
2723        let responder_busy = db.is_user_busy(responder_id).await?;
2724
2725        let pool = session.connection_pool().await;
2726        // Update responder with new contact
2727        let mut update = proto::UpdateContacts::default();
2728        if accept {
2729            update
2730                .contacts
2731                .push(contact_for_user(requester_id, requester_busy, &pool));
2732        }
2733        update
2734            .remove_incoming_requests
2735            .push(requester_id.to_proto());
2736        for connection_id in pool.user_connection_ids(responder_id) {
2737            session.peer.send(connection_id, update.clone())?;
2738        }
2739
2740        // Update requester with new contact
2741        let mut update = proto::UpdateContacts::default();
2742        if accept {
2743            update
2744                .contacts
2745                .push(contact_for_user(responder_id, responder_busy, &pool));
2746        }
2747        update
2748            .remove_outgoing_requests
2749            .push(responder_id.to_proto());
2750
2751        for connection_id in pool.user_connection_ids(requester_id) {
2752            session.peer.send(connection_id, update.clone())?;
2753        }
2754
2755        send_notifications(&pool, &session.peer, notifications);
2756    }
2757
2758    response.send(proto::Ack {})?;
2759    Ok(())
2760}
2761
2762/// Remove a contact.
2763async fn remove_contact(
2764    request: proto::RemoveContact,
2765    response: Response<proto::RemoveContact>,
2766    session: MessageContext,
2767) -> Result<()> {
2768    let requester_id = session.user_id();
2769    let responder_id = UserId::from_proto(request.user_id);
2770    let db = session.db().await;
2771    let (contact_accepted, deleted_notification_id) =
2772        db.remove_contact(requester_id, responder_id).await?;
2773
2774    let pool = session.connection_pool().await;
2775    // Update outgoing contact requests of requester
2776    let mut update = proto::UpdateContacts::default();
2777    if contact_accepted {
2778        update.remove_contacts.push(responder_id.to_proto());
2779    } else {
2780        update
2781            .remove_outgoing_requests
2782            .push(responder_id.to_proto());
2783    }
2784    for connection_id in pool.user_connection_ids(requester_id) {
2785        session.peer.send(connection_id, update.clone())?;
2786    }
2787
2788    // Update incoming contact requests of responder
2789    let mut update = proto::UpdateContacts::default();
2790    if contact_accepted {
2791        update.remove_contacts.push(requester_id.to_proto());
2792    } else {
2793        update
2794            .remove_incoming_requests
2795            .push(requester_id.to_proto());
2796    }
2797    for connection_id in pool.user_connection_ids(responder_id) {
2798        session.peer.send(connection_id, update.clone())?;
2799        if let Some(notification_id) = deleted_notification_id {
2800            session.peer.send(
2801                connection_id,
2802                proto::DeleteNotification {
2803                    notification_id: notification_id.to_proto(),
2804                },
2805            )?;
2806        }
2807    }
2808
2809    response.send(proto::Ack {})?;
2810    Ok(())
2811}
2812
2813fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2814    version.0.minor() < 139
2815}
2816
2817async fn subscribe_to_channels(
2818    _: proto::SubscribeToChannels,
2819    session: MessageContext,
2820) -> Result<()> {
2821    subscribe_user_to_channels(session.user_id(), &session).await?;
2822    Ok(())
2823}
2824
2825async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2826    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2827    let mut pool = session.connection_pool().await;
2828    for membership in &channels_for_user.channel_memberships {
2829        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2830    }
2831    session.peer.send(
2832        session.connection_id,
2833        build_update_user_channels(&channels_for_user),
2834    )?;
2835    session.peer.send(
2836        session.connection_id,
2837        build_channels_update(channels_for_user),
2838    )?;
2839    Ok(())
2840}
2841
2842/// Creates a new channel.
2843async fn create_channel(
2844    request: proto::CreateChannel,
2845    response: Response<proto::CreateChannel>,
2846    session: MessageContext,
2847) -> Result<()> {
2848    let db = session.db().await;
2849
2850    let parent_id = request.parent_id.map(ChannelId::from_proto);
2851    let (channel, membership) = db
2852        .create_channel(&request.name, parent_id, session.user_id())
2853        .await?;
2854
2855    let root_id = channel.root_id();
2856    let channel = Channel::from_model(channel);
2857
2858    response.send(proto::CreateChannelResponse {
2859        channel: Some(channel.to_proto()),
2860        parent_id: request.parent_id,
2861    })?;
2862
2863    let mut connection_pool = session.connection_pool().await;
2864    if let Some(membership) = membership {
2865        connection_pool.subscribe_to_channel(
2866            membership.user_id,
2867            membership.channel_id,
2868            membership.role,
2869        );
2870        let update = proto::UpdateUserChannels {
2871            channel_memberships: vec![proto::ChannelMembership {
2872                channel_id: membership.channel_id.to_proto(),
2873                role: membership.role.into(),
2874            }],
2875            ..Default::default()
2876        };
2877        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2878            session.peer.send(connection_id, update.clone())?;
2879        }
2880    }
2881
2882    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2883        if !role.can_see_channel(channel.visibility) {
2884            continue;
2885        }
2886
2887        let update = proto::UpdateChannels {
2888            channels: vec![channel.to_proto()],
2889            ..Default::default()
2890        };
2891        session.peer.send(connection_id, update.clone())?;
2892    }
2893
2894    Ok(())
2895}
2896
2897/// Delete a channel
2898async fn delete_channel(
2899    request: proto::DeleteChannel,
2900    response: Response<proto::DeleteChannel>,
2901    session: MessageContext,
2902) -> Result<()> {
2903    let db = session.db().await;
2904
2905    let channel_id = request.channel_id;
2906    let (root_channel, removed_channels) = db
2907        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2908        .await?;
2909    response.send(proto::Ack {})?;
2910
2911    // Notify members of removed channels
2912    let mut update = proto::UpdateChannels::default();
2913    update
2914        .delete_channels
2915        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2916
2917    let connection_pool = session.connection_pool().await;
2918    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2919        session.peer.send(connection_id, update.clone())?;
2920    }
2921
2922    Ok(())
2923}
2924
2925/// Invite someone to join a channel.
2926async fn invite_channel_member(
2927    request: proto::InviteChannelMember,
2928    response: Response<proto::InviteChannelMember>,
2929    session: MessageContext,
2930) -> Result<()> {
2931    let db = session.db().await;
2932    let channel_id = ChannelId::from_proto(request.channel_id);
2933    let invitee_id = UserId::from_proto(request.user_id);
2934    let InviteMemberResult {
2935        channel,
2936        notifications,
2937    } = db
2938        .invite_channel_member(
2939            channel_id,
2940            invitee_id,
2941            session.user_id(),
2942            request.role().into(),
2943        )
2944        .await?;
2945
2946    let update = proto::UpdateChannels {
2947        channel_invitations: vec![channel.to_proto()],
2948        ..Default::default()
2949    };
2950
2951    let connection_pool = session.connection_pool().await;
2952    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2953        session.peer.send(connection_id, update.clone())?;
2954    }
2955
2956    send_notifications(&connection_pool, &session.peer, notifications);
2957
2958    response.send(proto::Ack {})?;
2959    Ok(())
2960}
2961
2962/// remove someone from a channel
2963async fn remove_channel_member(
2964    request: proto::RemoveChannelMember,
2965    response: Response<proto::RemoveChannelMember>,
2966    session: MessageContext,
2967) -> Result<()> {
2968    let db = session.db().await;
2969    let channel_id = ChannelId::from_proto(request.channel_id);
2970    let member_id = UserId::from_proto(request.user_id);
2971
2972    let RemoveChannelMemberResult {
2973        membership_update,
2974        notification_id,
2975    } = db
2976        .remove_channel_member(channel_id, member_id, session.user_id())
2977        .await?;
2978
2979    let mut connection_pool = session.connection_pool().await;
2980    notify_membership_updated(
2981        &mut connection_pool,
2982        membership_update,
2983        member_id,
2984        &session.peer,
2985    );
2986    for connection_id in connection_pool.user_connection_ids(member_id) {
2987        if let Some(notification_id) = notification_id {
2988            session
2989                .peer
2990                .send(
2991                    connection_id,
2992                    proto::DeleteNotification {
2993                        notification_id: notification_id.to_proto(),
2994                    },
2995                )
2996                .trace_err();
2997        }
2998    }
2999
3000    response.send(proto::Ack {})?;
3001    Ok(())
3002}
3003
3004/// Toggle the channel between public and private.
3005/// Care is taken to maintain the invariant that public channels only descend from public channels,
3006/// (though members-only channels can appear at any point in the hierarchy).
3007async fn set_channel_visibility(
3008    request: proto::SetChannelVisibility,
3009    response: Response<proto::SetChannelVisibility>,
3010    session: MessageContext,
3011) -> Result<()> {
3012    let db = session.db().await;
3013    let channel_id = ChannelId::from_proto(request.channel_id);
3014    let visibility = request.visibility().into();
3015
3016    let channel_model = db
3017        .set_channel_visibility(channel_id, visibility, session.user_id())
3018        .await?;
3019    let root_id = channel_model.root_id();
3020    let channel = Channel::from_model(channel_model);
3021
3022    let mut connection_pool = session.connection_pool().await;
3023    for (user_id, role) in connection_pool
3024        .channel_user_ids(root_id)
3025        .collect::<Vec<_>>()
3026        .into_iter()
3027    {
3028        let update = if role.can_see_channel(channel.visibility) {
3029            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3030            proto::UpdateChannels {
3031                channels: vec![channel.to_proto()],
3032                ..Default::default()
3033            }
3034        } else {
3035            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3036            proto::UpdateChannels {
3037                delete_channels: vec![channel.id.to_proto()],
3038                ..Default::default()
3039            }
3040        };
3041
3042        for connection_id in connection_pool.user_connection_ids(user_id) {
3043            session.peer.send(connection_id, update.clone())?;
3044        }
3045    }
3046
3047    response.send(proto::Ack {})?;
3048    Ok(())
3049}
3050
3051/// Alter the role for a user in the channel.
3052async fn set_channel_member_role(
3053    request: proto::SetChannelMemberRole,
3054    response: Response<proto::SetChannelMemberRole>,
3055    session: MessageContext,
3056) -> Result<()> {
3057    let db = session.db().await;
3058    let channel_id = ChannelId::from_proto(request.channel_id);
3059    let member_id = UserId::from_proto(request.user_id);
3060    let result = db
3061        .set_channel_member_role(
3062            channel_id,
3063            session.user_id(),
3064            member_id,
3065            request.role().into(),
3066        )
3067        .await?;
3068
3069    match result {
3070        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3071            let mut connection_pool = session.connection_pool().await;
3072            notify_membership_updated(
3073                &mut connection_pool,
3074                membership_update,
3075                member_id,
3076                &session.peer,
3077            )
3078        }
3079        db::SetMemberRoleResult::InviteUpdated(channel) => {
3080            let update = proto::UpdateChannels {
3081                channel_invitations: vec![channel.to_proto()],
3082                ..Default::default()
3083            };
3084
3085            for connection_id in session
3086                .connection_pool()
3087                .await
3088                .user_connection_ids(member_id)
3089            {
3090                session.peer.send(connection_id, update.clone())?;
3091            }
3092        }
3093    }
3094
3095    response.send(proto::Ack {})?;
3096    Ok(())
3097}
3098
3099/// Change the name of a channel
3100async fn rename_channel(
3101    request: proto::RenameChannel,
3102    response: Response<proto::RenameChannel>,
3103    session: MessageContext,
3104) -> Result<()> {
3105    let db = session.db().await;
3106    let channel_id = ChannelId::from_proto(request.channel_id);
3107    let channel_model = db
3108        .rename_channel(channel_id, session.user_id(), &request.name)
3109        .await?;
3110    let root_id = channel_model.root_id();
3111    let channel = Channel::from_model(channel_model);
3112
3113    response.send(proto::RenameChannelResponse {
3114        channel: Some(channel.to_proto()),
3115    })?;
3116
3117    let connection_pool = session.connection_pool().await;
3118    let update = proto::UpdateChannels {
3119        channels: vec![channel.to_proto()],
3120        ..Default::default()
3121    };
3122    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3123        if role.can_see_channel(channel.visibility) {
3124            session.peer.send(connection_id, update.clone())?;
3125        }
3126    }
3127
3128    Ok(())
3129}
3130
3131/// Move a channel to a new parent.
3132async fn move_channel(
3133    request: proto::MoveChannel,
3134    response: Response<proto::MoveChannel>,
3135    session: MessageContext,
3136) -> Result<()> {
3137    let channel_id = ChannelId::from_proto(request.channel_id);
3138    let to = ChannelId::from_proto(request.to);
3139
3140    let (root_id, channels) = session
3141        .db()
3142        .await
3143        .move_channel(channel_id, to, session.user_id())
3144        .await?;
3145
3146    let connection_pool = session.connection_pool().await;
3147    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3148        let channels = channels
3149            .iter()
3150            .filter_map(|channel| {
3151                if role.can_see_channel(channel.visibility) {
3152                    Some(channel.to_proto())
3153                } else {
3154                    None
3155                }
3156            })
3157            .collect::<Vec<_>>();
3158        if channels.is_empty() {
3159            continue;
3160        }
3161
3162        let update = proto::UpdateChannels {
3163            channels,
3164            ..Default::default()
3165        };
3166
3167        session.peer.send(connection_id, update.clone())?;
3168    }
3169
3170    response.send(Ack {})?;
3171    Ok(())
3172}
3173
3174async fn reorder_channel(
3175    request: proto::ReorderChannel,
3176    response: Response<proto::ReorderChannel>,
3177    session: MessageContext,
3178) -> Result<()> {
3179    let channel_id = ChannelId::from_proto(request.channel_id);
3180    let direction = request.direction();
3181
3182    let updated_channels = session
3183        .db()
3184        .await
3185        .reorder_channel(channel_id, direction, session.user_id())
3186        .await?;
3187
3188    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3189        let connection_pool = session.connection_pool().await;
3190        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3191            let channels = updated_channels
3192                .iter()
3193                .filter_map(|channel| {
3194                    if role.can_see_channel(channel.visibility) {
3195                        Some(channel.to_proto())
3196                    } else {
3197                        None
3198                    }
3199                })
3200                .collect::<Vec<_>>();
3201
3202            if channels.is_empty() {
3203                continue;
3204            }
3205
3206            let update = proto::UpdateChannels {
3207                channels,
3208                ..Default::default()
3209            };
3210
3211            session.peer.send(connection_id, update.clone())?;
3212        }
3213    }
3214
3215    response.send(Ack {})?;
3216    Ok(())
3217}
3218
3219/// Get the list of channel members
3220async fn get_channel_members(
3221    request: proto::GetChannelMembers,
3222    response: Response<proto::GetChannelMembers>,
3223    session: MessageContext,
3224) -> Result<()> {
3225    let db = session.db().await;
3226    let channel_id = ChannelId::from_proto(request.channel_id);
3227    let limit = if request.limit == 0 {
3228        u16::MAX as u64
3229    } else {
3230        request.limit
3231    };
3232    let (members, users) = db
3233        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3234        .await?;
3235    response.send(proto::GetChannelMembersResponse { members, users })?;
3236    Ok(())
3237}
3238
3239/// Accept or decline a channel invitation.
3240async fn respond_to_channel_invite(
3241    request: proto::RespondToChannelInvite,
3242    response: Response<proto::RespondToChannelInvite>,
3243    session: MessageContext,
3244) -> Result<()> {
3245    let db = session.db().await;
3246    let channel_id = ChannelId::from_proto(request.channel_id);
3247    let RespondToChannelInvite {
3248        membership_update,
3249        notifications,
3250    } = db
3251        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3252        .await?;
3253
3254    let mut connection_pool = session.connection_pool().await;
3255    if let Some(membership_update) = membership_update {
3256        notify_membership_updated(
3257            &mut connection_pool,
3258            membership_update,
3259            session.user_id(),
3260            &session.peer,
3261        );
3262    } else {
3263        let update = proto::UpdateChannels {
3264            remove_channel_invitations: vec![channel_id.to_proto()],
3265            ..Default::default()
3266        };
3267
3268        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3269            session.peer.send(connection_id, update.clone())?;
3270        }
3271    };
3272
3273    send_notifications(&connection_pool, &session.peer, notifications);
3274
3275    response.send(proto::Ack {})?;
3276
3277    Ok(())
3278}
3279
3280/// Join the channels' room
3281async fn join_channel(
3282    request: proto::JoinChannel,
3283    response: Response<proto::JoinChannel>,
3284    session: MessageContext,
3285) -> Result<()> {
3286    let channel_id = ChannelId::from_proto(request.channel_id);
3287    join_channel_internal(channel_id, Box::new(response), session).await
3288}
3289
/// A response handle that can deliver a `proto::JoinRoomResponse`.
///
/// Implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can reply through
/// either kind of handle.
trait JoinChannelInternalResponse {
    /// Consume the handle and send `result` as the reply.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully-qualified path: inherent methods take precedence, so this calls
        // `Response`'s own `send` instead of recursing into this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully-qualified path: inherent methods take precedence, so this calls
        // `Response`'s own `send` instead of recursing into this trait method.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3303
/// Joins the room attached to `channel_id` for the session's user and answers
/// the originating request through `response`.
///
/// Steps:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first.
/// 2. Join the channel's room in the database. This yields the joined room, an
///    optional membership update, and the user's channel role.
/// 3. If a LiveKit client is configured, mint credentials:
///    - `ChannelRole::Guest` gets a guest token with `can_publish: false`.
///    - Every other role gets a room token with `can_publish: true`.
///    A token error goes through `trace_err` and yields no connection info
///    instead of failing the join.
/// 4. Reply to the client, apply and broadcast any membership update, and
///    broadcast the room state to its participants.
/// 5. Broadcast the channel's participant list, then refresh the user's
///    contacts.
///
/// # Errors
/// Returns an error if any of these fail:
/// - the stale-room cleanup, the database join, the reply, or the contact
///   update;
/// - the joined room has no channel. This is checked only after the reply has
///   already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    // This block owns the `db` guard and the first connection-pool guard, so
    // both are released before the pool is locked again for `channel_updated`.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` calls `session.db()` itself, so release this
            // guard first and re-acquire it afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token creation failed.
        // Guests get a guest token and may not publish; other roles get a room
        // token and may publish.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        // Joining may have changed the user's channel memberships. If so, update
        // the pool's subscriptions and push the new state to the user's connections.
        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // A channel join is expected to produce a channel. If it is missing, the
    // handler errors here, after the reply above has already been sent.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3397
3398/// Start editing the channel notes
3399async fn join_channel_buffer(
3400    request: proto::JoinChannelBuffer,
3401    response: Response<proto::JoinChannelBuffer>,
3402    session: MessageContext,
3403) -> Result<()> {
3404    let db = session.db().await;
3405    let channel_id = ChannelId::from_proto(request.channel_id);
3406
3407    let open_response = db
3408        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3409        .await?;
3410
3411    let collaborators = open_response.collaborators.clone();
3412    response.send(open_response)?;
3413
3414    let update = UpdateChannelBufferCollaborators {
3415        channel_id: channel_id.to_proto(),
3416        collaborators: collaborators.clone(),
3417    };
3418    channel_buffer_updated(
3419        session.connection_id,
3420        collaborators
3421            .iter()
3422            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3423        &update,
3424        &session.peer,
3425    );
3426
3427    Ok(())
3428}
3429
3430/// Edit the channel notes
3431async fn update_channel_buffer(
3432    request: proto::UpdateChannelBuffer,
3433    session: MessageContext,
3434) -> Result<()> {
3435    let db = session.db().await;
3436    let channel_id = ChannelId::from_proto(request.channel_id);
3437
3438    let (collaborators, epoch, version) = db
3439        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3440        .await?;
3441
3442    channel_buffer_updated(
3443        session.connection_id,
3444        collaborators.clone(),
3445        &proto::UpdateChannelBuffer {
3446            channel_id: channel_id.to_proto(),
3447            operations: request.operations,
3448        },
3449        &session.peer,
3450    );
3451
3452    let pool = &*session.connection_pool().await;
3453
3454    let non_collaborators =
3455        pool.channel_connection_ids(channel_id)
3456            .filter_map(|(connection_id, _)| {
3457                if collaborators.contains(&connection_id) {
3458                    None
3459                } else {
3460                    Some(connection_id)
3461                }
3462            });
3463
3464    broadcast(None, non_collaborators, |peer_id| {
3465        session.peer.send(
3466            peer_id,
3467            proto::UpdateChannels {
3468                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3469                    channel_id: channel_id.to_proto(),
3470                    epoch: epoch as u64,
3471                    version: version.clone(),
3472                }],
3473                ..Default::default()
3474            },
3475        )
3476    });
3477
3478    Ok(())
3479}
3480
3481/// Rejoin the channel notes after a connection blip
3482async fn rejoin_channel_buffers(
3483    request: proto::RejoinChannelBuffers,
3484    response: Response<proto::RejoinChannelBuffers>,
3485    session: MessageContext,
3486) -> Result<()> {
3487    let db = session.db().await;
3488    let buffers = db
3489        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3490        .await?;
3491
3492    for rejoined_buffer in &buffers {
3493        let collaborators_to_notify = rejoined_buffer
3494            .buffer
3495            .collaborators
3496            .iter()
3497            .filter_map(|c| Some(c.peer_id?.into()));
3498        channel_buffer_updated(
3499            session.connection_id,
3500            collaborators_to_notify,
3501            &proto::UpdateChannelBufferCollaborators {
3502                channel_id: rejoined_buffer.buffer.channel_id,
3503                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3504            },
3505            &session.peer,
3506        );
3507    }
3508
3509    response.send(proto::RejoinChannelBuffersResponse {
3510        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3511    })?;
3512
3513    Ok(())
3514}
3515
3516/// Stop editing the channel notes
3517async fn leave_channel_buffer(
3518    request: proto::LeaveChannelBuffer,
3519    response: Response<proto::LeaveChannelBuffer>,
3520    session: MessageContext,
3521) -> Result<()> {
3522    let db = session.db().await;
3523    let channel_id = ChannelId::from_proto(request.channel_id);
3524
3525    let left_buffer = db
3526        .leave_channel_buffer(channel_id, session.connection_id)
3527        .await?;
3528
3529    response.send(Ack {})?;
3530
3531    channel_buffer_updated(
3532        session.connection_id,
3533        left_buffer.connections,
3534        &proto::UpdateChannelBufferCollaborators {
3535            channel_id: channel_id.to_proto(),
3536            collaborators: left_buffer.collaborators,
3537        },
3538        &session.peer,
3539    );
3540
3541    Ok(())
3542}
3543
3544fn channel_buffer_updated<T: EnvelopedMessage>(
3545    sender_id: ConnectionId,
3546    collaborators: impl IntoIterator<Item = ConnectionId>,
3547    message: &T,
3548    peer: &Peer,
3549) {
3550    broadcast(Some(sender_id), collaborators, |peer_id| {
3551        peer.send(peer_id, message.clone())
3552    });
3553}
3554
3555fn send_notifications(
3556    connection_pool: &ConnectionPool,
3557    peer: &Peer,
3558    notifications: db::NotificationBatch,
3559) {
3560    for (user_id, notification) in notifications {
3561        for connection_id in connection_pool.user_connection_ids(user_id) {
3562            if let Err(error) = peer.send(
3563                connection_id,
3564                proto::AddNotification {
3565                    notification: Some(notification.clone()),
3566                },
3567            ) {
3568                tracing::error!(
3569                    "failed to send notification to {:?} {}",
3570                    connection_id,
3571                    error
3572                );
3573            }
3574        }
3575    }
3576}
3577
3578/// Send a message to the channel
3579async fn send_channel_message(
3580    request: proto::SendChannelMessage,
3581    response: Response<proto::SendChannelMessage>,
3582    session: MessageContext,
3583) -> Result<()> {
3584    // Validate the message body.
3585    let body = request.body.trim().to_string();
3586    if body.len() > MAX_MESSAGE_LEN {
3587        return Err(anyhow!("message is too long"))?;
3588    }
3589    if body.is_empty() {
3590        return Err(anyhow!("message can't be blank"))?;
3591    }
3592
3593    // TODO: adjust mentions if body is trimmed
3594
3595    let timestamp = OffsetDateTime::now_utc();
3596    let nonce = request.nonce.context("nonce can't be blank")?;
3597
3598    let channel_id = ChannelId::from_proto(request.channel_id);
3599    let CreatedChannelMessage {
3600        message_id,
3601        participant_connection_ids,
3602        notifications,
3603    } = session
3604        .db()
3605        .await
3606        .create_channel_message(
3607            channel_id,
3608            session.user_id(),
3609            &body,
3610            &request.mentions,
3611            timestamp,
3612            nonce.clone().into(),
3613            request.reply_to_message_id.map(MessageId::from_proto),
3614        )
3615        .await?;
3616
3617    let message = proto::ChannelMessage {
3618        sender_id: session.user_id().to_proto(),
3619        id: message_id.to_proto(),
3620        body,
3621        mentions: request.mentions,
3622        timestamp: timestamp.unix_timestamp() as u64,
3623        nonce: Some(nonce),
3624        reply_to_message_id: request.reply_to_message_id,
3625        edited_at: None,
3626    };
3627    broadcast(
3628        Some(session.connection_id),
3629        participant_connection_ids.clone(),
3630        |connection| {
3631            session.peer.send(
3632                connection,
3633                proto::ChannelMessageSent {
3634                    channel_id: channel_id.to_proto(),
3635                    message: Some(message.clone()),
3636                },
3637            )
3638        },
3639    );
3640    response.send(proto::SendChannelMessageResponse {
3641        message: Some(message),
3642    })?;
3643
3644    let pool = &*session.connection_pool().await;
3645    let non_participants =
3646        pool.channel_connection_ids(channel_id)
3647            .filter_map(|(connection_id, _)| {
3648                if participant_connection_ids.contains(&connection_id) {
3649                    None
3650                } else {
3651                    Some(connection_id)
3652                }
3653            });
3654    broadcast(None, non_participants, |peer_id| {
3655        session.peer.send(
3656            peer_id,
3657            proto::UpdateChannels {
3658                latest_channel_message_ids: vec![proto::ChannelMessageId {
3659                    channel_id: channel_id.to_proto(),
3660                    message_id: message_id.to_proto(),
3661                }],
3662                ..Default::default()
3663            },
3664        )
3665    });
3666    send_notifications(pool, &session.peer, notifications);
3667
3668    Ok(())
3669}
3670
3671/// Delete a channel message
3672async fn remove_channel_message(
3673    request: proto::RemoveChannelMessage,
3674    response: Response<proto::RemoveChannelMessage>,
3675    session: MessageContext,
3676) -> Result<()> {
3677    let channel_id = ChannelId::from_proto(request.channel_id);
3678    let message_id = MessageId::from_proto(request.message_id);
3679    let (connection_ids, existing_notification_ids) = session
3680        .db()
3681        .await
3682        .remove_channel_message(channel_id, message_id, session.user_id())
3683        .await?;
3684
3685    broadcast(
3686        Some(session.connection_id),
3687        connection_ids,
3688        move |connection| {
3689            session.peer.send(connection, request.clone())?;
3690
3691            for notification_id in &existing_notification_ids {
3692                session.peer.send(
3693                    connection,
3694                    proto::DeleteNotification {
3695                        notification_id: (*notification_id).to_proto(),
3696                    },
3697                )?;
3698            }
3699
3700            Ok(())
3701        },
3702    );
3703    response.send(proto::Ack {})?;
3704    Ok(())
3705}
3706
3707async fn update_channel_message(
3708    request: proto::UpdateChannelMessage,
3709    response: Response<proto::UpdateChannelMessage>,
3710    session: MessageContext,
3711) -> Result<()> {
3712    let channel_id = ChannelId::from_proto(request.channel_id);
3713    let message_id = MessageId::from_proto(request.message_id);
3714    let updated_at = OffsetDateTime::now_utc();
3715    let UpdatedChannelMessage {
3716        message_id,
3717        participant_connection_ids,
3718        notifications,
3719        reply_to_message_id,
3720        timestamp,
3721        deleted_mention_notification_ids,
3722        updated_mention_notifications,
3723    } = session
3724        .db()
3725        .await
3726        .update_channel_message(
3727            channel_id,
3728            message_id,
3729            session.user_id(),
3730            request.body.as_str(),
3731            &request.mentions,
3732            updated_at,
3733        )
3734        .await?;
3735
3736    let nonce = request.nonce.clone().context("nonce can't be blank")?;
3737
3738    let message = proto::ChannelMessage {
3739        sender_id: session.user_id().to_proto(),
3740        id: message_id.to_proto(),
3741        body: request.body.clone(),
3742        mentions: request.mentions.clone(),
3743        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3744        nonce: Some(nonce),
3745        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3746        edited_at: Some(updated_at.unix_timestamp() as u64),
3747    };
3748
3749    response.send(proto::Ack {})?;
3750
3751    let pool = &*session.connection_pool().await;
3752    broadcast(
3753        Some(session.connection_id),
3754        participant_connection_ids,
3755        |connection| {
3756            session.peer.send(
3757                connection,
3758                proto::ChannelMessageUpdate {
3759                    channel_id: channel_id.to_proto(),
3760                    message: Some(message.clone()),
3761                },
3762            )?;
3763
3764            for notification_id in &deleted_mention_notification_ids {
3765                session.peer.send(
3766                    connection,
3767                    proto::DeleteNotification {
3768                        notification_id: (*notification_id).to_proto(),
3769                    },
3770                )?;
3771            }
3772
3773            for notification in &updated_mention_notifications {
3774                session.peer.send(
3775                    connection,
3776                    proto::UpdateNotification {
3777                        notification: Some(notification.clone()),
3778                    },
3779                )?;
3780            }
3781
3782            Ok(())
3783        },
3784    );
3785
3786    send_notifications(pool, &session.peer, notifications);
3787
3788    Ok(())
3789}
3790
3791/// Mark a channel message as read
3792async fn acknowledge_channel_message(
3793    request: proto::AckChannelMessage,
3794    session: MessageContext,
3795) -> Result<()> {
3796    let channel_id = ChannelId::from_proto(request.channel_id);
3797    let message_id = MessageId::from_proto(request.message_id);
3798    let notifications = session
3799        .db()
3800        .await
3801        .observe_channel_message(channel_id, session.user_id(), message_id)
3802        .await?;
3803    send_notifications(
3804        &*session.connection_pool().await,
3805        &session.peer,
3806        notifications,
3807    );
3808    Ok(())
3809}
3810
3811/// Mark a buffer version as synced
3812async fn acknowledge_buffer_version(
3813    request: proto::AckBufferOperation,
3814    session: MessageContext,
3815) -> Result<()> {
3816    let buffer_id = BufferId::from_proto(request.buffer_id);
3817    session
3818        .db()
3819        .await
3820        .observe_buffer_version(
3821            buffer_id,
3822            session.user_id(),
3823            request.epoch as i32,
3824            &request.version,
3825        )
3826        .await?;
3827    Ok(())
3828}
3829
3830/// Get a Supermaven API key for the user
3831async fn get_supermaven_api_key(
3832    _request: proto::GetSupermavenApiKey,
3833    response: Response<proto::GetSupermavenApiKey>,
3834    session: MessageContext,
3835) -> Result<()> {
3836    let user_id: String = session.user_id().to_string();
3837    if !session.is_staff() {
3838        return Err(anyhow!("supermaven not enabled for this account"))?;
3839    }
3840
3841    let email = session.email().context("user must have an email")?;
3842
3843    let supermaven_admin_api = session
3844        .supermaven_client
3845        .as_ref()
3846        .context("supermaven not configured")?;
3847
3848    let result = supermaven_admin_api
3849        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3850        .await?;
3851
3852    response.send(proto::GetSupermavenApiKeyResponse {
3853        api_key: result.api_key,
3854    })?;
3855
3856    Ok(())
3857}
3858
3859/// Start receiving chat updates for a channel
3860async fn join_channel_chat(
3861    request: proto::JoinChannelChat,
3862    response: Response<proto::JoinChannelChat>,
3863    session: MessageContext,
3864) -> Result<()> {
3865    let channel_id = ChannelId::from_proto(request.channel_id);
3866
3867    let db = session.db().await;
3868    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3869        .await?;
3870    let messages = db
3871        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3872        .await?;
3873    response.send(proto::JoinChannelChatResponse {
3874        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3875        messages,
3876    })?;
3877    Ok(())
3878}
3879
3880/// Stop receiving chat updates for a channel
3881async fn leave_channel_chat(
3882    request: proto::LeaveChannelChat,
3883    session: MessageContext,
3884) -> Result<()> {
3885    let channel_id = ChannelId::from_proto(request.channel_id);
3886    session
3887        .db()
3888        .await
3889        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3890        .await?;
3891    Ok(())
3892}
3893
3894/// Retrieve the chat history for a channel
3895async fn get_channel_messages(
3896    request: proto::GetChannelMessages,
3897    response: Response<proto::GetChannelMessages>,
3898    session: MessageContext,
3899) -> Result<()> {
3900    let channel_id = ChannelId::from_proto(request.channel_id);
3901    let messages = session
3902        .db()
3903        .await
3904        .get_channel_messages(
3905            channel_id,
3906            session.user_id(),
3907            MESSAGE_COUNT_PER_PAGE,
3908            Some(MessageId::from_proto(request.before_message_id)),
3909        )
3910        .await?;
3911    response.send(proto::GetChannelMessagesResponse {
3912        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3913        messages,
3914    })?;
3915    Ok(())
3916}
3917
3918/// Retrieve specific chat messages
3919async fn get_channel_messages_by_id(
3920    request: proto::GetChannelMessagesById,
3921    response: Response<proto::GetChannelMessagesById>,
3922    session: MessageContext,
3923) -> Result<()> {
3924    let message_ids = request
3925        .message_ids
3926        .iter()
3927        .map(|id| MessageId::from_proto(*id))
3928        .collect::<Vec<_>>();
3929    let messages = session
3930        .db()
3931        .await
3932        .get_channel_messages_by_id(session.user_id(), &message_ids)
3933        .await?;
3934    response.send(proto::GetChannelMessagesResponse {
3935        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3936        messages,
3937    })?;
3938    Ok(())
3939}
3940
3941/// Retrieve the current users notifications
3942async fn get_notifications(
3943    request: proto::GetNotifications,
3944    response: Response<proto::GetNotifications>,
3945    session: MessageContext,
3946) -> Result<()> {
3947    let notifications = session
3948        .db()
3949        .await
3950        .get_notifications(
3951            session.user_id(),
3952            NOTIFICATION_COUNT_PER_PAGE,
3953            request.before_id.map(db::NotificationId::from_proto),
3954        )
3955        .await?;
3956    response.send(proto::GetNotificationsResponse {
3957        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3958        notifications,
3959    })?;
3960    Ok(())
3961}
3962
3963/// Mark notifications as read
3964async fn mark_notification_as_read(
3965    request: proto::MarkNotificationRead,
3966    response: Response<proto::MarkNotificationRead>,
3967    session: MessageContext,
3968) -> Result<()> {
3969    let database = &session.db().await;
3970    let notifications = database
3971        .mark_notification_as_read_by_id(
3972            session.user_id(),
3973            NotificationId::from_proto(request.notification_id),
3974        )
3975        .await?;
3976    send_notifications(
3977        &*session.connection_pool().await,
3978        &session.peer,
3979        notifications,
3980    );
3981    response.send(proto::Ack {})?;
3982    Ok(())
3983}
3984
3985fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3986    let message = match message {
3987        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3988        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3989        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3990        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3991        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3992            code: frame.code.into(),
3993            reason: frame.reason.as_str().to_owned().into(),
3994        })),
3995        // We should never receive a frame while reading the message, according
3996        // to the `tungstenite` maintainers:
3997        //
3998        // > It cannot occur when you read messages from the WebSocket, but it
3999        // > can be used when you want to send the raw frames (e.g. you want to
4000        // > send the frames to the WebSocket without composing the full message first).
4001        // >
4002        // > — https://github.com/snapview/tungstenite-rs/issues/268
4003        TungsteniteMessage::Frame(_) => {
4004            bail!("received an unexpected frame while reading the message")
4005        }
4006    };
4007
4008    Ok(message)
4009}
4010
4011fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4012    match message {
4013        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4014        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4015        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4016        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4017        AxumMessage::Close(frame) => {
4018            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4019                code: frame.code.into(),
4020                reason: frame.reason.as_ref().into(),
4021            }))
4022        }
4023    }
4024}
4025
4026fn notify_membership_updated(
4027    connection_pool: &mut ConnectionPool,
4028    result: MembershipUpdated,
4029    user_id: UserId,
4030    peer: &Peer,
4031) {
4032    for membership in &result.new_channels.channel_memberships {
4033        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4034    }
4035    for channel_id in &result.removed_channels {
4036        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4037    }
4038
4039    let user_channels_update = proto::UpdateUserChannels {
4040        channel_memberships: result
4041            .new_channels
4042            .channel_memberships
4043            .iter()
4044            .map(|cm| proto::ChannelMembership {
4045                channel_id: cm.channel_id.to_proto(),
4046                role: cm.role.into(),
4047            })
4048            .collect(),
4049        ..Default::default()
4050    };
4051
4052    let mut update = build_channels_update(result.new_channels);
4053    update.delete_channels = result
4054        .removed_channels
4055        .into_iter()
4056        .map(|id| id.to_proto())
4057        .collect();
4058    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4059
4060    for connection_id in connection_pool.user_connection_ids(user_id) {
4061        peer.send(connection_id, user_channels_update.clone())
4062            .trace_err();
4063        peer.send(connection_id, update.clone()).trace_err();
4064    }
4065}
4066
4067fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4068    proto::UpdateUserChannels {
4069        channel_memberships: channels
4070            .channel_memberships
4071            .iter()
4072            .map(|m| proto::ChannelMembership {
4073                channel_id: m.channel_id.to_proto(),
4074                role: m.role.into(),
4075            })
4076            .collect(),
4077        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4078        observed_channel_message_id: channels.observed_channel_messages.clone(),
4079    }
4080}
4081
4082fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4083    let mut update = proto::UpdateChannels::default();
4084
4085    for channel in channels.channels {
4086        update.channels.push(channel.to_proto());
4087    }
4088
4089    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4090    update.latest_channel_message_ids = channels.latest_channel_messages;
4091
4092    for (channel_id, participants) in channels.channel_participants {
4093        update
4094            .channel_participants
4095            .push(proto::ChannelParticipants {
4096                channel_id: channel_id.to_proto(),
4097                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4098            });
4099    }
4100
4101    for channel in channels.invited_channels {
4102        update.channel_invitations.push(channel.to_proto());
4103    }
4104
4105    update
4106}
4107
4108fn build_initial_contacts_update(
4109    contacts: Vec<db::Contact>,
4110    pool: &ConnectionPool,
4111) -> proto::UpdateContacts {
4112    let mut update = proto::UpdateContacts::default();
4113
4114    for contact in contacts {
4115        match contact {
4116            db::Contact::Accepted { user_id, busy } => {
4117                update.contacts.push(contact_for_user(user_id, busy, pool));
4118            }
4119            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4120            db::Contact::Incoming { user_id } => {
4121                update
4122                    .incoming_requests
4123                    .push(proto::IncomingContactRequest {
4124                        requester_id: user_id.to_proto(),
4125                    })
4126            }
4127        }
4128    }
4129
4130    update
4131}
4132
4133fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4134    proto::Contact {
4135        user_id: user_id.to_proto(),
4136        online: pool.is_user_online(user_id),
4137        busy,
4138    }
4139}
4140
4141fn room_updated(room: &proto::Room, peer: &Peer) {
4142    broadcast(
4143        None,
4144        room.participants
4145            .iter()
4146            .filter_map(|participant| Some(participant.peer_id?.into())),
4147        |peer_id| {
4148            peer.send(
4149                peer_id,
4150                proto::RoomUpdated {
4151                    room: Some(room.clone()),
4152                },
4153            )
4154        },
4155    );
4156}
4157
4158fn channel_updated(
4159    channel: &db::channel::Model,
4160    room: &proto::Room,
4161    peer: &Peer,
4162    pool: &ConnectionPool,
4163) {
4164    let participants = room
4165        .participants
4166        .iter()
4167        .map(|p| p.user_id)
4168        .collect::<Vec<_>>();
4169
4170    broadcast(
4171        None,
4172        pool.channel_connection_ids(channel.root_id())
4173            .filter_map(|(channel_id, role)| {
4174                role.can_see_channel(channel.visibility)
4175                    .then_some(channel_id)
4176            }),
4177        |peer_id| {
4178            peer.send(
4179                peer_id,
4180                proto::UpdateChannels {
4181                    channel_participants: vec![proto::ChannelParticipants {
4182                        channel_id: channel.id.to_proto(),
4183                        participant_user_ids: participants.clone(),
4184                    }],
4185                    ..Default::default()
4186                },
4187            )
4188        },
4189    );
4190}
4191
4192async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4193    let db = session.db().await;
4194
4195    let contacts = db.get_contacts(user_id).await?;
4196    let busy = db.is_user_busy(user_id).await?;
4197
4198    let pool = session.connection_pool().await;
4199    let updated_contact = contact_for_user(user_id, busy, &pool);
4200    for contact in contacts {
4201        if let db::Contact::Accepted {
4202            user_id: contact_user_id,
4203            ..
4204        } = contact
4205        {
4206            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4207                session
4208                    .peer
4209                    .send(
4210                        contact_conn_id,
4211                        proto::UpdateContacts {
4212                            contacts: vec![updated_contact.clone()],
4213                            remove_contacts: Default::default(),
4214                            incoming_requests: Default::default(),
4215                            remove_incoming_requests: Default::default(),
4216                            outgoing_requests: Default::default(),
4217                            remove_outgoing_requests: Default::default(),
4218                        },
4219                    )
4220                    .trace_err();
4221            }
4222        }
4223    }
4224    Ok(())
4225}
4226
4227async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4228    let mut contacts_to_update = HashSet::default();
4229
4230    let room_id;
4231    let canceled_calls_to_user_ids;
4232    let livekit_room;
4233    let delete_livekit_room;
4234    let room;
4235    let channel;
4236
4237    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4238        contacts_to_update.insert(session.user_id());
4239
4240        for project in left_room.left_projects.values() {
4241            project_left(project, session);
4242        }
4243
4244        room_id = RoomId::from_proto(left_room.room.id);
4245        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4246        livekit_room = mem::take(&mut left_room.room.livekit_room);
4247        delete_livekit_room = left_room.deleted;
4248        room = mem::take(&mut left_room.room);
4249        channel = mem::take(&mut left_room.channel);
4250
4251        room_updated(&room, &session.peer);
4252    } else {
4253        return Ok(());
4254    }
4255
4256    if let Some(channel) = channel {
4257        channel_updated(
4258            &channel,
4259            &room,
4260            &session.peer,
4261            &*session.connection_pool().await,
4262        );
4263    }
4264
4265    {
4266        let pool = session.connection_pool().await;
4267        for canceled_user_id in canceled_calls_to_user_ids {
4268            for connection_id in pool.user_connection_ids(canceled_user_id) {
4269                session
4270                    .peer
4271                    .send(
4272                        connection_id,
4273                        proto::CallCanceled {
4274                            room_id: room_id.to_proto(),
4275                        },
4276                    )
4277                    .trace_err();
4278            }
4279            contacts_to_update.insert(canceled_user_id);
4280        }
4281    }
4282
4283    for contact_user_id in contacts_to_update {
4284        update_user_contacts(contact_user_id, session).await?;
4285    }
4286
4287    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4288        live_kit
4289            .remove_participant(livekit_room.clone(), session.user_id().to_string())
4290            .await
4291            .trace_err();
4292
4293        if delete_livekit_room {
4294            live_kit.delete_room(livekit_room).await.trace_err();
4295        }
4296    }
4297
4298    Ok(())
4299}
4300
4301async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4302    let left_channel_buffers = session
4303        .db()
4304        .await
4305        .leave_channel_buffers(session.connection_id)
4306        .await?;
4307
4308    for left_buffer in left_channel_buffers {
4309        channel_buffer_updated(
4310            session.connection_id,
4311            left_buffer.connections,
4312            &proto::UpdateChannelBufferCollaborators {
4313                channel_id: left_buffer.channel_id.to_proto(),
4314                collaborators: left_buffer.collaborators,
4315            },
4316            &session.peer,
4317        );
4318    }
4319
4320    Ok(())
4321}
4322
4323fn project_left(project: &db::LeftProject, session: &Session) {
4324    for connection_id in &project.connection_ids {
4325        if project.should_unshare {
4326            session
4327                .peer
4328                .send(
4329                    *connection_id,
4330                    proto::UnshareProject {
4331                        project_id: project.id.to_proto(),
4332                    },
4333                )
4334                .trace_err();
4335        } else {
4336            session
4337                .peer
4338                .send(
4339                    *connection_id,
4340                    proto::RemoveProjectCollaborator {
4341                        project_id: project.id.to_proto(),
4342                        peer_id: Some(session.connection_id.into()),
4343                    },
4344                )
4345                .trace_err();
4346        }
4347    }
4348}
4349
4350pub trait ResultExt {
4351    type Ok;
4352
4353    fn trace_err(self) -> Option<Self::Ok>;
4354}
4355
4356impl<T, E> ResultExt for Result<T, E>
4357where
4358    E: std::fmt::Debug,
4359{
4360    type Ok = T;
4361
4362    #[track_caller]
4363    fn trace_err(self) -> Option<T> {
4364        match self {
4365            Ok(value) => Some(value),
4366            Err(error) => {
4367                tracing::error!("{:?}", error);
4368                None
4369            }
4370        }
4371    }
4372}