rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::{
   5    AppState, Error, Result, auth,
   6    db::{
   7        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   8        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
   9        NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
  10        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  11    },
  12    executor::Executor,
  13};
  14use anyhow::{Context as _, anyhow, bail};
  15use async_tungstenite::tungstenite::{
  16    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  17};
  18use axum::headers::UserAgent;
  19use axum::{
  20    Extension, Router, TypedHeader,
  21    body::Body,
  22    extract::{
  23        ConnectInfo, WebSocketUpgrade,
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use futures::TryFutureExt as _;
  36use reqwest_client::ReqwestClient;
  37use rpc::proto::{MultiLspQuery, split_repository_update};
  38use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  39use tracing::Span;
  40
  41use futures::{
  42    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  43    stream::FuturesUnordered,
  44};
  45use prometheus::{IntGauge, register_int_gauge};
  46use rpc::{
  47    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  50        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52};
  53use semantic_version::SemanticVersion;
  54use serde::{Serialize, Serializer};
  55use std::{
  56    any::TypeId,
  57    future::Future,
  58    marker::PhantomData,
  59    mem,
  60    net::SocketAddr,
  61    ops::{Deref, DerefMut},
  62    rc::Rc,
  63    sync::{
  64        Arc, OnceLock,
  65        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  66    },
  67    time::{Duration, Instant},
  68};
  69use time::OffsetDateTime;
  70use tokio::sync::{Semaphore, watch};
  71use tower::ServiceBuilder;
  72use tracing::{
  73    Instrument,
  74    field::{self},
  75    info_span, instrument,
  76};
  77
  78pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  79
  80// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  81pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  82
  83const MESSAGE_COUNT_PER_PAGE: usize = 100;
  84const MAX_MESSAGE_LEN: usize = 1024;
  85const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  86const MAX_CONCURRENT_CONNECTIONS: usize = 512;
  87
  88static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
  89
  90const TOTAL_DURATION_MS: &str = "total_duration_ms";
  91const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
  92const QUEUE_DURATION_MS: &str = "queue_duration_ms";
  93const HOST_WAITING_MS: &str = "host_waiting_ms";
  94
  95type MessageHandler =
  96    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
  97
  98pub struct ConnectionGuard;
  99
 100impl ConnectionGuard {
 101    pub fn try_acquire() -> Result<Self, ()> {
 102        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
 103        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 104            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 105            tracing::error!(
 106                "too many concurrent connections: {}",
 107                current_connections + 1
 108            );
 109            return Err(());
 110        }
 111        Ok(ConnectionGuard)
 112    }
 113}
 114
 115impl Drop for ConnectionGuard {
 116    fn drop(&mut self) {
 117        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 118    }
 119}
 120
 121struct Response<R> {
 122    peer: Arc<Peer>,
 123    receipt: Receipt<R>,
 124    responded: Arc<AtomicBool>,
 125}
 126
 127impl<R: RequestMessage> Response<R> {
 128    fn send(self, payload: R::Response) -> Result<()> {
 129        self.responded.store(true, SeqCst);
 130        self.peer.respond(self.receipt, payload)?;
 131        Ok(())
 132    }
 133}
 134
 135#[derive(Clone, Debug)]
 136pub enum Principal {
 137    User(User),
 138    Impersonated { user: User, admin: User },
 139}
 140
 141impl Principal {
 142    fn update_span(&self, span: &tracing::Span) {
 143        match &self {
 144            Principal::User(user) => {
 145                span.record("user_id", user.id.0);
 146                span.record("login", &user.github_login);
 147            }
 148            Principal::Impersonated { user, admin } => {
 149                span.record("user_id", user.id.0);
 150                span.record("login", &user.github_login);
 151                span.record("impersonator", &admin.github_login);
 152            }
 153        }
 154    }
 155}
 156
 157#[derive(Clone)]
 158struct MessageContext {
 159    session: Session,
 160    span: tracing::Span,
 161}
 162
 163impl Deref for MessageContext {
 164    type Target = Session;
 165
 166    fn deref(&self) -> &Self::Target {
 167        &self.session
 168    }
 169}
 170
 171impl MessageContext {
 172    pub fn forward_request<T: RequestMessage>(
 173        &self,
 174        receiver_id: ConnectionId,
 175        request: T,
 176    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 177        let request_start_time = Instant::now();
 178        let span = self.span.clone();
 179        tracing::info!("start forwarding request");
 180        self.peer
 181            .forward_request(self.connection_id, receiver_id, request)
 182            .inspect(move |_| {
 183                span.record(
 184                    HOST_WAITING_MS,
 185                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 186                );
 187            })
 188            .inspect_err(|_| tracing::error!("error forwarding request"))
 189            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 190    }
 191}
 192
 193#[derive(Clone)]
 194struct Session {
 195    principal: Principal,
 196    connection_id: ConnectionId,
 197    db: Arc<tokio::sync::Mutex<DbHandle>>,
 198    peer: Arc<Peer>,
 199    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 200    app_state: Arc<AppState>,
 201    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 202    /// The GeoIP country code for the user.
 203    #[allow(unused)]
 204    geoip_country_code: Option<String>,
 205    #[allow(unused)]
 206    system_id: Option<String>,
 207    _executor: Executor,
 208}
 209
 210impl Session {
 211    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 212        #[cfg(test)]
 213        tokio::task::yield_now().await;
 214        let guard = self.db.lock().await;
 215        #[cfg(test)]
 216        tokio::task::yield_now().await;
 217        guard
 218    }
 219
 220    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 221        #[cfg(test)]
 222        tokio::task::yield_now().await;
 223        let guard = self.connection_pool.lock();
 224        ConnectionPoolGuard {
 225            guard,
 226            _not_send: PhantomData,
 227        }
 228    }
 229
 230    fn is_staff(&self) -> bool {
 231        match &self.principal {
 232            Principal::User(user) => user.admin,
 233            Principal::Impersonated { .. } => true,
 234        }
 235    }
 236
 237    fn user_id(&self) -> UserId {
 238        match &self.principal {
 239            Principal::User(user) => user.id,
 240            Principal::Impersonated { user, .. } => user.id,
 241        }
 242    }
 243
 244    pub fn email(&self) -> Option<String> {
 245        match &self.principal {
 246            Principal::User(user) => user.email_address.clone(),
 247            Principal::Impersonated { user, .. } => user.email_address.clone(),
 248        }
 249    }
 250}
 251
 252impl Debug for Session {
 253    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 254        let mut result = f.debug_struct("Session");
 255        match &self.principal {
 256            Principal::User(user) => {
 257                result.field("user", &user.github_login);
 258            }
 259            Principal::Impersonated { user, admin } => {
 260                result.field("user", &user.github_login);
 261                result.field("impersonator", &admin.github_login);
 262            }
 263        }
 264        result.field("connection_id", &self.connection_id).finish()
 265    }
 266}
 267
 268struct DbHandle(Arc<Database>);
 269
 270impl Deref for DbHandle {
 271    type Target = Database;
 272
 273    fn deref(&self) -> &Self::Target {
 274        self.0.as_ref()
 275    }
 276}
 277
 278pub struct Server {
 279    id: parking_lot::Mutex<ServerId>,
 280    peer: Arc<Peer>,
 281    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 282    app_state: Arc<AppState>,
 283    handlers: HashMap<TypeId, MessageHandler>,
 284    teardown: watch::Sender<bool>,
 285}
 286
 287pub(crate) struct ConnectionPoolGuard<'a> {
 288    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 289    _not_send: PhantomData<Rc<()>>,
 290}
 291
 292#[derive(Serialize)]
 293pub struct ServerSnapshot<'a> {
 294    peer: &'a Peer,
 295    #[serde(serialize_with = "serialize_deref")]
 296    connection_pool: ConnectionPoolGuard<'a>,
 297}
 298
 299pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 300where
 301    S: Serializer,
 302    T: Deref<Target = U>,
 303    U: Serialize,
 304{
 305    Serialize::serialize(value.deref(), serializer)
 306}
 307
 308impl Server {
 309    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 310        let mut server = Self {
 311            id: parking_lot::Mutex::new(id),
 312            peer: Peer::new(id.0 as u32),
 313            app_state: app_state.clone(),
 314            connection_pool: Default::default(),
 315            handlers: Default::default(),
 316            teardown: watch::channel(false).0,
 317        };
 318
 319        server
 320            .add_request_handler(ping)
 321            .add_request_handler(create_room)
 322            .add_request_handler(join_room)
 323            .add_request_handler(rejoin_room)
 324            .add_request_handler(leave_room)
 325            .add_request_handler(set_room_participant_role)
 326            .add_request_handler(call)
 327            .add_request_handler(cancel_call)
 328            .add_message_handler(decline_call)
 329            .add_request_handler(update_participant_location)
 330            .add_request_handler(share_project)
 331            .add_message_handler(unshare_project)
 332            .add_request_handler(join_project)
 333            .add_message_handler(leave_project)
 334            .add_request_handler(update_project)
 335            .add_request_handler(update_worktree)
 336            .add_request_handler(update_repository)
 337            .add_request_handler(remove_repository)
 338            .add_message_handler(start_language_server)
 339            .add_message_handler(update_language_server)
 340            .add_message_handler(update_diagnostic_summary)
 341            .add_message_handler(update_worktree_settings)
 342            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 343            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 344            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 345            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 346            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 347            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 348            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 349            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 350            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 351            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 352            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 353            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 354            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 355            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 356            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 357            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 358            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 359            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 360            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 361            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 362            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 363            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 364            .add_request_handler(
 365                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 366            )
 367            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 368            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 369            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 370            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 371            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
 372            .add_request_handler(
 373                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 374            )
 375            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 376            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 377            .add_request_handler(
 378                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 379            )
 380            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 381            .add_request_handler(
 382                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 383            )
 384            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 385            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 386            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 387            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 388            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 389            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 390            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 391            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 392            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 393            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 394            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 395            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 396            .add_request_handler(
 397                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 398            )
 399            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 400            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 401            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 402            .add_request_handler(multi_lsp_query)
 403            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 404            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 405            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 406            .add_message_handler(create_buffer_for_peer)
 407            .add_request_handler(update_buffer)
 408            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 409            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 410            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 411            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 412            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 413            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 414            .add_message_handler(
 415                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 416            )
 417            .add_request_handler(get_users)
 418            .add_request_handler(fuzzy_search_users)
 419            .add_request_handler(request_contact)
 420            .add_request_handler(remove_contact)
 421            .add_request_handler(respond_to_contact_request)
 422            .add_message_handler(subscribe_to_channels)
 423            .add_request_handler(create_channel)
 424            .add_request_handler(delete_channel)
 425            .add_request_handler(invite_channel_member)
 426            .add_request_handler(remove_channel_member)
 427            .add_request_handler(set_channel_member_role)
 428            .add_request_handler(set_channel_visibility)
 429            .add_request_handler(rename_channel)
 430            .add_request_handler(join_channel_buffer)
 431            .add_request_handler(leave_channel_buffer)
 432            .add_message_handler(update_channel_buffer)
 433            .add_request_handler(rejoin_channel_buffers)
 434            .add_request_handler(get_channel_members)
 435            .add_request_handler(respond_to_channel_invite)
 436            .add_request_handler(join_channel)
 437            .add_request_handler(join_channel_chat)
 438            .add_message_handler(leave_channel_chat)
 439            .add_request_handler(send_channel_message)
 440            .add_request_handler(remove_channel_message)
 441            .add_request_handler(update_channel_message)
 442            .add_request_handler(get_channel_messages)
 443            .add_request_handler(get_channel_messages_by_id)
 444            .add_request_handler(get_notifications)
 445            .add_request_handler(mark_notification_as_read)
 446            .add_request_handler(move_channel)
 447            .add_request_handler(reorder_channel)
 448            .add_request_handler(follow)
 449            .add_message_handler(unfollow)
 450            .add_message_handler(update_followers)
 451            .add_message_handler(acknowledge_channel_message)
 452            .add_message_handler(acknowledge_buffer_version)
 453            .add_request_handler(get_supermaven_api_key)
 454            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 455            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 456            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 457            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 458            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 459            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 460            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 461            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 462            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 463            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 464            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 465            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 466            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 467            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 468            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 469            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 470            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 471            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 472            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 473            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 474            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 475            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 476            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 477            .add_message_handler(update_context);
 478
 479        Arc::new(server)
 480    }
 481
    /// Schedules a one-time cleanup of resources left behind by earlier server
    /// instances, then returns immediately.
    ///
    /// A detached task, instrumented with the `start server` span, awaits a sleep of
    /// [`CLEANUP_TIMEOUT`]. It then runs the steps below best-effort: every failure is
    /// only logged via `trace_err`.
    /// - `delete_stale_channel_chat_participants`.
    /// - `stale_server_resource_ids`, then for each stale channel buffer:
    ///   `clear_stale_channel_buffer_collaborators`, with the refreshed collaborator
    ///   list sent to every remaining connection.
    /// - For each stale room: `clear_stale_room_participants`.
    ///   - Notify the room and its channel.
    ///   - Send `CallCanceled` for canceled calls.
    ///   - Push updated contact entries to the accepted contacts of affected users.
    ///   - Delete the LiveKit room if no participants remain.
    /// - `delete_stale_channel_chat_participants` (again), `clear_old_worktree_entries`
    ///   and `delete_stale_servers`.
    ///
    /// The `ServerId` is read once, when `start` is called, and that value is used
    /// throughout the task.
    ///
    /// # Errors
    /// Always returns `Ok(())`. Errors inside the spawned task are logged, not returned.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Prune stale collaborators from each channel buffer and push the
                    // refreshed list to every connection still attached to it.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults mean "nothing to do" if the room could not be refreshed.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and users whose incoming call was
                            // canceled get their contact entry refreshed below.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the next `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // DB lookups happen before the pool lock is taken. The lock is held
                        // only for the synchronous fan-out to accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room once no participants remain in it.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // NOTE(review): this repeats the identical call at the top of the task.
                // Confirm whether the second pass is intentional (e.g. to catch
                // participants affected by the room cleanup above) or redundant.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 652
 653    pub fn teardown(&self) {
 654        self.peer.teardown();
 655        self.connection_pool.lock().reset();
 656        let _ = self.teardown.send(true);
 657    }
 658
 659    #[cfg(test)]
 660    pub fn reset(&self, id: ServerId) {
 661        self.teardown();
 662        *self.id.lock() = id;
 663        self.peer.reset(id.0 as u32);
 664        let _ = self.teardown.send(false);
 665    }
 666
 667    #[cfg(test)]
 668    pub fn id(&self) -> ServerId {
 669        *self.id.lock()
 670    }
 671
 672    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 673    where
 674        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
 675        Fut: 'static + Send + Future<Output = Result<()>>,
 676        M: EnvelopedMessage,
 677    {
 678        let prev_handler = self.handlers.insert(
 679            TypeId::of::<M>(),
 680            Box::new(move |envelope, session, span| {
 681                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 682                let received_at = envelope.received_at;
 683                tracing::info!("message received");
 684                let start_time = Instant::now();
 685                let future = (handler)(
 686                    *envelope,
 687                    MessageContext {
 688                        session,
 689                        span: span.clone(),
 690                    },
 691                );
 692                async move {
 693                    let result = future.await;
 694                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 695                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 696                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 697                    span.record(TOTAL_DURATION_MS, total_duration_ms);
 698                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
 699                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
 700                    match result {
 701                        Err(error) => {
 702                            tracing::error!(?error, "error handling message")
 703                        }
 704                        Ok(()) => tracing::info!("finished handling message"),
 705                    }
 706                }
 707                .boxed()
 708            }),
 709        );
 710        if prev_handler.is_some() {
 711            panic!("registered a handler for the same message twice");
 712        }
 713        self
 714    }
 715
 716    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 717    where
 718        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 719        Fut: 'static + Send + Future<Output = Result<()>>,
 720        M: EnvelopedMessage,
 721    {
 722        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 723        self
 724    }
 725
 726    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 727    where
 728        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 729        Fut: Send + Future<Output = Result<()>>,
 730        M: RequestMessage,
 731    {
 732        let handler = Arc::new(handler);
 733        self.add_handler(move |envelope, session| {
 734            let receipt = envelope.receipt();
 735            let handler = handler.clone();
 736            async move {
 737                let peer = session.peer.clone();
 738                let responded = Arc::new(AtomicBool::default());
 739                let response = Response {
 740                    peer: peer.clone(),
 741                    responded: responded.clone(),
 742                    receipt,
 743                };
 744                match (handler)(envelope.payload, response, session).await {
 745                    Ok(()) => {
 746                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 747                            Ok(())
 748                        } else {
 749                            Err(anyhow!("handler did not send a response"))?
 750                        }
 751                    }
 752                    Err(error) => {
 753                        let proto_err = match &error {
 754                            Error::Internal(err) => err.to_proto(),
 755                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 756                        };
 757                        peer.respond_with_error(receipt, proto_err)?;
 758                        Err(error)
 759                    }
 760                }
 761            }
 762        })
 763    }
 764
    /// Drives one client connection from handshake to sign-out.
    ///
    /// Returns a future that captures no borrow of `self` (hence `use<>`; it owns a clone of
    /// the `Arc`). When polled it:
    /// 1. returns immediately if teardown has already been signalled;
    /// 2. registers `connection` with the peer and builds the per-connection [`Session`];
    /// 3. sends the initial client update (which also delivers the new connection id through
    ///    `send_connection_id`, if provided) and then drops `connection_guard`;
    /// 4. dispatches incoming messages to the registered handlers until the connection
    ///    closes, I/O fails, or teardown is signalled;
    /// 5. after a close or I/O failure (but not after teardown) runs `connection_lost`.
    ///
    /// Everything runs inside a "handle connection" span seeded with the request metadata.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        // Owned handle so the returned future does not borrow `self`.
        let this = self.clone();
        // Fields left empty here are filled in below or by `principal.update_span`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse the connection if the server is already shutting down.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            // The closure lets the peer sleep on this executor (presumably for keepalive /
            // timeout handling — confirm in `Peer::add_connection`).
            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // In this function the HTTP client is only used to build the optional
            // Supermaven admin client below.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                // NOTE(review): this early return skips `connection_lost`, so neither
                // `peer.disconnect` nor `remove_connection` runs for this connection, even
                // though `send_initial_client_update` may already have added it to the
                // connection pool. Confirm whether peer/pool entries are cleaned up elsewhere.
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // Handshake finished: release the guard acquired before the WebSocket upgrade.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            //
            // A semaphore permit is taken before each message is read and released when its
            // handler finishes, so at most MAX_CONCURRENT_HANDLERS handlers (foreground and
            // background combined) are in flight for this connection; once they are exhausted,
            // no further messages are read until one completes.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            // Number of permits currently held, i.e. handlers in flight.
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // Biased select: branches are polled in the order written, so teardown beats
                // I/O, and completing existing handlers beats reading a new message.
                futures::select_biased! {
                    // Teardown exits right away, without running `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                multi_lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            // Enter the span while the handler is invoked so that its
                            // synchronous part (e.g. the "message received" log) is attributed to it.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // The permit is held until the handler finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor; foreground
                                // ones are polled by this loop through `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that are still running.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
 956
 957    async fn send_initial_client_update(
 958        &self,
 959        connection_id: ConnectionId,
 960        zed_version: ZedVersion,
 961        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 962        session: &Session,
 963    ) -> Result<()> {
 964        self.peer.send(
 965            connection_id,
 966            proto::Hello {
 967                peer_id: Some(connection_id.into()),
 968            },
 969        )?;
 970        tracing::info!("sent hello message");
 971        if let Some(send_connection_id) = send_connection_id.take() {
 972            let _ = send_connection_id.send(connection_id);
 973        }
 974
 975        match &session.principal {
 976            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 977                if !user.connected_once {
 978                    self.peer.send(connection_id, proto::ShowContacts {})?;
 979                    self.app_state
 980                        .db
 981                        .set_user_connected_once(user.id, true)
 982                        .await?;
 983                }
 984
 985                let contacts = self.app_state.db.get_contacts(user.id).await?;
 986
 987                {
 988                    let mut pool = self.connection_pool.lock();
 989                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 990                    self.peer.send(
 991                        connection_id,
 992                        build_initial_contacts_update(contacts, &pool),
 993                    )?;
 994                }
 995
 996                if should_auto_subscribe_to_channels(zed_version) {
 997                    subscribe_user_to_channels(user.id, session).await?;
 998                }
 999
1000                if let Some(incoming_call) =
1001                    self.app_state.db.incoming_call_for_user(user.id).await?
1002                {
1003                    self.peer.send(connection_id, incoming_call)?;
1004                }
1005
1006                update_user_contacts(user.id, session).await?;
1007            }
1008        }
1009
1010        Ok(())
1011    }
1012
1013    pub async fn invite_code_redeemed(
1014        self: &Arc<Self>,
1015        inviter_id: UserId,
1016        invitee_id: UserId,
1017    ) -> Result<()> {
1018        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1019            if let Some(code) = &user.invite_code {
1020                let pool = self.connection_pool.lock();
1021                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1022                for connection_id in pool.user_connection_ids(inviter_id) {
1023                    self.peer.send(
1024                        connection_id,
1025                        proto::UpdateContacts {
1026                            contacts: vec![invitee_contact.clone()],
1027                            ..Default::default()
1028                        },
1029                    )?;
1030                    self.peer.send(
1031                        connection_id,
1032                        proto::UpdateInviteInfo {
1033                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1034                            count: user.invite_count as u32,
1035                        },
1036                    )?;
1037                }
1038            }
1039        }
1040        Ok(())
1041    }
1042
1043    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1044        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1045            if let Some(invite_code) = &user.invite_code {
1046                let pool = self.connection_pool.lock();
1047                for connection_id in pool.user_connection_ids(user_id) {
1048                    self.peer.send(
1049                        connection_id,
1050                        proto::UpdateInviteInfo {
1051                            url: format!(
1052                                "{}{}",
1053                                self.app_state.config.invite_link_prefix, invite_code
1054                            ),
1055                            count: user.invite_count as u32,
1056                        },
1057                    )?;
1058                }
1059            }
1060        }
1061        Ok(())
1062    }
1063
1064    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1065        ServerSnapshot {
1066            connection_pool: ConnectionPoolGuard {
1067                guard: self.connection_pool.lock(),
1068                _not_send: PhantomData,
1069            },
1070            peer: &self.peer,
1071        }
1072    }
1073}
1074
1075impl Deref for ConnectionPoolGuard<'_> {
1076    type Target = ConnectionPool;
1077
1078    fn deref(&self) -> &Self::Target {
1079        &self.guard
1080    }
1081}
1082
1083impl DerefMut for ConnectionPoolGuard<'_> {
1084    fn deref_mut(&mut self) -> &mut Self::Target {
1085        &mut self.guard
1086    }
1087}
1088
1089impl Drop for ConnectionPoolGuard<'_> {
1090    fn drop(&mut self) {
1091        #[cfg(test)]
1092        self.check_invariants();
1093    }
1094}
1095
1096fn broadcast<F>(
1097    sender_id: Option<ConnectionId>,
1098    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1099    mut f: F,
1100) where
1101    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1102{
1103    for receiver_id in receiver_ids {
1104        if Some(receiver_id) != sender_id {
1105            if let Err(error) = f(receiver_id) {
1106                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1107            }
1108        }
1109    }
1110}
1111
1112pub struct ProtocolVersion(u32);
1113
1114impl Header for ProtocolVersion {
1115    fn name() -> &'static HeaderName {
1116        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1117        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1118    }
1119
1120    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1121    where
1122        Self: Sized,
1123        I: Iterator<Item = &'i axum::http::HeaderValue>,
1124    {
1125        let version = values
1126            .next()
1127            .ok_or_else(axum::headers::Error::invalid)?
1128            .to_str()
1129            .map_err(|_| axum::headers::Error::invalid())?
1130            .parse()
1131            .map_err(|_| axum::headers::Error::invalid())?;
1132        Ok(Self(version))
1133    }
1134
1135    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1136        values.extend([self.0.to_string().parse().unwrap()]);
1137    }
1138}
1139
1140pub struct AppVersionHeader(SemanticVersion);
1141impl Header for AppVersionHeader {
1142    fn name() -> &'static HeaderName {
1143        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1144        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1145    }
1146
1147    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1148    where
1149        Self: Sized,
1150        I: Iterator<Item = &'i axum::http::HeaderValue>,
1151    {
1152        let version = values
1153            .next()
1154            .ok_or_else(axum::headers::Error::invalid)?
1155            .to_str()
1156            .map_err(|_| axum::headers::Error::invalid())?
1157            .parse()
1158            .map_err(|_| axum::headers::Error::invalid())?;
1159        Ok(Self(version))
1160    }
1161
1162    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1163        values.extend([self.0.to_string().parse().unwrap()]);
1164    }
1165}
1166
1167#[derive(Debug)]
1168pub struct ReleaseChannelHeader(String);
1169
1170impl Header for ReleaseChannelHeader {
1171    fn name() -> &'static HeaderName {
1172        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1173        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1174    }
1175
1176    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1177    where
1178        Self: Sized,
1179        I: Iterator<Item = &'i axum::http::HeaderValue>,
1180    {
1181        Ok(Self(
1182            values
1183                .next()
1184                .ok_or_else(axum::headers::Error::invalid)?
1185                .to_str()
1186                .map_err(|_| axum::headers::Error::invalid())?
1187                .to_owned(),
1188        ))
1189    }
1190
1191    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1192        values.extend([self.0.parse().unwrap()]);
1193    }
1194}
1195
1196pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1197    Router::new()
1198        .route("/rpc", get(handle_websocket_request))
1199        .layer(
1200            ServiceBuilder::new()
1201                .layer(Extension(server.app_state.clone()))
1202                .layer(middleware::from_fn(auth::validate_header)),
1203        )
1204        .route("/metrics", get(handle_metrics))
1205        .layer(Extension(server))
1206}
1207
1208pub async fn handle_websocket_request(
1209    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1210    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1211    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1212    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1213    Extension(server): Extension<Arc<Server>>,
1214    Extension(principal): Extension<Principal>,
1215    user_agent: Option<TypedHeader<UserAgent>>,
1216    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1217    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1218    ws: WebSocketUpgrade,
1219) -> axum::response::Response {
1220    if protocol_version != rpc::PROTOCOL_VERSION {
1221        return (
1222            StatusCode::UPGRADE_REQUIRED,
1223            "client must be upgraded".to_string(),
1224        )
1225            .into_response();
1226    }
1227
1228    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1229        return (
1230            StatusCode::UPGRADE_REQUIRED,
1231            "no version header found".to_string(),
1232        )
1233            .into_response();
1234    };
1235
1236    let release_channel = release_channel_header.map(|header| header.0.0);
1237
1238    if !version.can_collaborate() {
1239        return (
1240            StatusCode::UPGRADE_REQUIRED,
1241            "client must be upgraded".to_string(),
1242        )
1243            .into_response();
1244    }
1245
1246    let socket_address = socket_address.to_string();
1247
1248    // Acquire connection guard before WebSocket upgrade
1249    let connection_guard = match ConnectionGuard::try_acquire() {
1250        Ok(guard) => guard,
1251        Err(()) => {
1252            return (
1253                StatusCode::SERVICE_UNAVAILABLE,
1254                "Too many concurrent connections",
1255            )
1256                .into_response();
1257        }
1258    };
1259
1260    ws.on_upgrade(move |socket| {
1261        let socket = socket
1262            .map_ok(to_tungstenite_message)
1263            .err_into()
1264            .with(|message| async move { to_axum_message(message) });
1265        let connection = Connection::new(Box::pin(socket));
1266        async move {
1267            server
1268                .handle_connection(
1269                    connection,
1270                    socket_address,
1271                    principal,
1272                    version,
1273                    release_channel,
1274                    user_agent.map(|header| header.to_string()),
1275                    country_code_header.map(|header| header.to_string()),
1276                    system_id_header.map(|header| header.to_string()),
1277                    None,
1278                    Executor::Production,
1279                    Some(connection_guard),
1280                )
1281                .await;
1282        }
1283    })
1284}
1285
1286pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1287    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1288    let connections_metric = CONNECTIONS_METRIC
1289        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1290
1291    let connections = server
1292        .connection_pool
1293        .lock()
1294        .connections()
1295        .filter(|connection| !connection.admin)
1296        .count();
1297    connections_metric.set(connections as _);
1298
1299    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1300    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1301        register_int_gauge!(
1302            "shared_projects",
1303            "number of open projects with one or more guests"
1304        )
1305        .unwrap()
1306    });
1307
1308    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1309    shared_projects_metric.set(shared_projects as _);
1310
1311    let encoder = prometheus::TextEncoder::new();
1312    let metric_families = prometheus::gather();
1313    let encoded_metrics = encoder
1314        .encode_to_string(&metric_families)
1315        .map_err(|err| anyhow!("{err}"))?;
1316    Ok(encoded_metrics)
1317}
1318
/// Cleans up after a connection's message loop has ended.
///
/// Immediately disconnects the peer and removes the connection from the pool. If the pool
/// removal fails, it returns that error. It then reports the lost connection to the database;
/// errors there are only passed to `trace_err`. Finally it waits for whichever comes first:
/// - `RECONNECT_TIMEOUT` elapsing: it leaves the room and channel buffers, declines a
///   pending call if the user has no other online connection, and refreshes contacts.
///   This is presumably the grace period for the client to reconnect; confirm against the
///   rejoin logic;
/// - a teardown signal: the cleanup above is skipped and `Ok(())` is returned.
///
/// `#[instrument(err, ...)]` records a returned error on the function's span.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    // NOTE(review): an error here returns before the database is told about the lost
    // connection below; confirm that is intended.
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased: if both are ready at once, the timeout branch is taken.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline calls when no other connection of this user is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1365
1366/// Acknowledges a ping from a client, used to keep the connection alive.
1367async fn ping(
1368    _: proto::Ping,
1369    response: Response<proto::Ping>,
1370    _session: MessageContext,
1371) -> Result<()> {
1372    response.send(proto::Ack {})?;
1373    Ok(())
1374}
1375
1376/// Creates a new room for calling (outside of channels)
1377async fn create_room(
1378    _request: proto::CreateRoom,
1379    response: Response<proto::CreateRoom>,
1380    session: MessageContext,
1381) -> Result<()> {
1382    let livekit_room = nanoid::nanoid!(30);
1383
1384    let live_kit_connection_info = util::maybe!(async {
1385        let live_kit = session.app_state.livekit_client.as_ref();
1386        let live_kit = live_kit?;
1387        let user_id = session.user_id().to_string();
1388
1389        let token = live_kit
1390            .room_token(&livekit_room, &user_id.to_string())
1391            .trace_err()?;
1392
1393        Some(proto::LiveKitConnectionInfo {
1394            server_url: live_kit.url().into(),
1395            token,
1396            can_publish: true,
1397        })
1398    })
1399    .await;
1400
1401    let room = session
1402        .db()
1403        .await
1404        .create_room(session.user_id(), session.connection_id, &livekit_room)
1405        .await?;
1406
1407    response.send(proto::CreateRoomResponse {
1408        room: Some(room.clone()),
1409        live_kit_connection_info,
1410    })?;
1411
1412    update_user_contacts(session.user_id(), &session).await?;
1413    Ok(())
1414}
1415
1416/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1417async fn join_room(
1418    request: proto::JoinRoom,
1419    response: Response<proto::JoinRoom>,
1420    session: MessageContext,
1421) -> Result<()> {
1422    let room_id = RoomId::from_proto(request.id);
1423
1424    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1425
1426    if let Some(channel_id) = channel_id {
1427        return join_channel_internal(channel_id, Box::new(response), session).await;
1428    }
1429
1430    let joined_room = {
1431        let room = session
1432            .db()
1433            .await
1434            .join_room(room_id, session.user_id(), session.connection_id)
1435            .await?;
1436        room_updated(&room.room, &session.peer);
1437        room.into_inner()
1438    };
1439
1440    for connection_id in session
1441        .connection_pool()
1442        .await
1443        .user_connection_ids(session.user_id())
1444    {
1445        session
1446            .peer
1447            .send(
1448                connection_id,
1449                proto::CallCanceled {
1450                    room_id: room_id.to_proto(),
1451                },
1452            )
1453            .trace_err();
1454    }
1455
1456    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1457    {
1458        live_kit
1459            .room_token(
1460                &joined_room.room.livekit_room,
1461                &session.user_id().to_string(),
1462            )
1463            .trace_err()
1464            .map(|token| proto::LiveKitConnectionInfo {
1465                server_url: live_kit.url().into(),
1466                token,
1467                can_publish: true,
1468            })
1469    } else {
1470        None
1471    };
1472
1473    response.send(proto::JoinRoomResponse {
1474        room: Some(joined_room.room),
1475        channel_id: None,
1476        live_kit_connection_info,
1477    })?;
1478
1479    update_user_contacts(session.user_id(), &session).await?;
1480    Ok(())
1481}
1482
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Flow:
/// 1. The database `rejoin_room` call is given the request, the caller's user
///    id and the current `session.connection_id`. It returns the room together
///    with its `reshared_projects` and `rejoined_projects`.
/// 2. The caller receives a `RejoinRoomResponse` describing the room and those
///    projects, and the room update is broadcast.
/// 3. For every reshared project, each collaborator is told that the peer id
///    changed from the project's `old_connection_id` to the current
///    connection. The project's worktree list is then forwarded to those
///    collaborators as an `UpdateProject`. The rejoined projects are handled
///    by `notify_rejoined_projects`.
/// 4. Once the transaction guard has been dropped, the room's channel (if any)
///    is broadcast via `channel_updated`, and `update_user_contacts` is run for
///    the caller.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    // Declared outside the inner scope so they outlive the transaction guard.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Reply to the caller with the room and the per-project state.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // For each reshared project:
        // - point its collaborators at the caller's new connection id
        //   (send failures are only logged);
        // - forward the project's worktree list to them.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Notify collaborators and stream state for the rejoined projects.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Unwrap the guard so `room` and `channel` can escape this scope.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // Runs only after the transaction guard above has been dropped.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1574
1575fn notify_rejoined_projects(
1576    rejoined_projects: &mut Vec<RejoinedProject>,
1577    session: &Session,
1578) -> Result<()> {
1579    for project in rejoined_projects.iter() {
1580        for collaborator in &project.collaborators {
1581            session
1582                .peer
1583                .send(
1584                    collaborator.connection_id,
1585                    proto::UpdateProjectCollaborator {
1586                        project_id: project.id.to_proto(),
1587                        old_peer_id: Some(project.old_connection_id.into()),
1588                        new_peer_id: Some(session.connection_id.into()),
1589                    },
1590                )
1591                .trace_err();
1592        }
1593    }
1594
1595    for project in rejoined_projects {
1596        for worktree in mem::take(&mut project.worktrees) {
1597            // Stream this worktree's entries.
1598            let message = proto::UpdateWorktree {
1599                project_id: project.id.to_proto(),
1600                worktree_id: worktree.id,
1601                abs_path: worktree.abs_path.clone(),
1602                root_name: worktree.root_name,
1603                updated_entries: worktree.updated_entries,
1604                removed_entries: worktree.removed_entries,
1605                scan_id: worktree.scan_id,
1606                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1607                updated_repositories: worktree.updated_repositories,
1608                removed_repositories: worktree.removed_repositories,
1609            };
1610            for update in proto::split_worktree_update(message) {
1611                session.peer.send(session.connection_id, update)?;
1612            }
1613
1614            // Stream this worktree's diagnostics.
1615            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1616            if let Some(summary) = worktree_diagnostics.next() {
1617                let message = proto::UpdateDiagnosticSummary {
1618                    project_id: project.id.to_proto(),
1619                    worktree_id: worktree.id,
1620                    summary: Some(summary),
1621                    more_summaries: worktree_diagnostics.collect(),
1622                };
1623                session.peer.send(session.connection_id, message)?;
1624            }
1625
1626            for settings_file in worktree.settings_files {
1627                session.peer.send(
1628                    session.connection_id,
1629                    proto::UpdateWorktreeSettings {
1630                        project_id: project.id.to_proto(),
1631                        worktree_id: worktree.id,
1632                        path: settings_file.path,
1633                        content: Some(settings_file.content),
1634                        kind: Some(settings_file.kind.to_proto().into()),
1635                    },
1636                )?;
1637            }
1638        }
1639
1640        for repository in mem::take(&mut project.updated_repositories) {
1641            for update in split_repository_update(repository) {
1642                session.peer.send(session.connection_id, update)?;
1643            }
1644        }
1645
1646        for id in mem::take(&mut project.removed_repositories) {
1647            session.peer.send(
1648                session.connection_id,
1649                proto::RemoveRepository {
1650                    project_id: project.id.to_proto(),
1651                    id,
1652                },
1653            )?;
1654        }
1655    }
1656
1657    Ok(())
1658}
1659
1660/// leave room disconnects from the room.
1661async fn leave_room(
1662    _: proto::LeaveRoom,
1663    response: Response<proto::LeaveRoom>,
1664    session: MessageContext,
1665) -> Result<()> {
1666    leave_room_for_session(&session, session.connection_id).await?;
1667    response.send(proto::Ack {})?;
1668    Ok(())
1669}
1670
1671/// Updates the permissions of someone else in the room.
1672async fn set_room_participant_role(
1673    request: proto::SetRoomParticipantRole,
1674    response: Response<proto::SetRoomParticipantRole>,
1675    session: MessageContext,
1676) -> Result<()> {
1677    let user_id = UserId::from_proto(request.user_id);
1678    let role = ChannelRole::from(request.role());
1679
1680    let (livekit_room, can_publish) = {
1681        let room = session
1682            .db()
1683            .await
1684            .set_room_participant_role(
1685                session.user_id(),
1686                RoomId::from_proto(request.room_id),
1687                user_id,
1688                role,
1689            )
1690            .await?;
1691
1692        let livekit_room = room.livekit_room.clone();
1693        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1694        room_updated(&room, &session.peer);
1695        (livekit_room, can_publish)
1696    };
1697
1698    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1699        live_kit
1700            .update_participant(
1701                livekit_room.clone(),
1702                request.user_id.to_string(),
1703                livekit_api::proto::ParticipantPermission {
1704                    can_subscribe: true,
1705                    can_publish,
1706                    can_publish_data: can_publish,
1707                    hidden: false,
1708                    recorder: false,
1709                },
1710            )
1711            .await
1712            .trace_err();
1713    }
1714
1715    response.send(proto::Ack {})?;
1716    Ok(())
1717}
1718
1719/// Call someone else into the current room
1720async fn call(
1721    request: proto::Call,
1722    response: Response<proto::Call>,
1723    session: MessageContext,
1724) -> Result<()> {
1725    let room_id = RoomId::from_proto(request.room_id);
1726    let calling_user_id = session.user_id();
1727    let calling_connection_id = session.connection_id;
1728    let called_user_id = UserId::from_proto(request.called_user_id);
1729    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1730    if !session
1731        .db()
1732        .await
1733        .has_contact(calling_user_id, called_user_id)
1734        .await?
1735    {
1736        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1737    }
1738
1739    let incoming_call = {
1740        let (room, incoming_call) = &mut *session
1741            .db()
1742            .await
1743            .call(
1744                room_id,
1745                calling_user_id,
1746                calling_connection_id,
1747                called_user_id,
1748                initial_project_id,
1749            )
1750            .await?;
1751        room_updated(room, &session.peer);
1752        mem::take(incoming_call)
1753    };
1754    update_user_contacts(called_user_id, &session).await?;
1755
1756    let mut calls = session
1757        .connection_pool()
1758        .await
1759        .user_connection_ids(called_user_id)
1760        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1761        .collect::<FuturesUnordered<_>>();
1762
1763    while let Some(call_response) = calls.next().await {
1764        match call_response.as_ref() {
1765            Ok(_) => {
1766                response.send(proto::Ack {})?;
1767                return Ok(());
1768            }
1769            Err(_) => {
1770                call_response.trace_err();
1771            }
1772        }
1773    }
1774
1775    {
1776        let room = session
1777            .db()
1778            .await
1779            .call_failed(room_id, called_user_id)
1780            .await?;
1781        room_updated(&room, &session.peer);
1782    }
1783    update_user_contacts(called_user_id, &session).await?;
1784
1785    Err(anyhow!("failed to ring user"))?
1786}
1787
1788/// Cancel an outgoing call.
1789async fn cancel_call(
1790    request: proto::CancelCall,
1791    response: Response<proto::CancelCall>,
1792    session: MessageContext,
1793) -> Result<()> {
1794    let called_user_id = UserId::from_proto(request.called_user_id);
1795    let room_id = RoomId::from_proto(request.room_id);
1796    {
1797        let room = session
1798            .db()
1799            .await
1800            .cancel_call(room_id, session.connection_id, called_user_id)
1801            .await?;
1802        room_updated(&room, &session.peer);
1803    }
1804
1805    for connection_id in session
1806        .connection_pool()
1807        .await
1808        .user_connection_ids(called_user_id)
1809    {
1810        session
1811            .peer
1812            .send(
1813                connection_id,
1814                proto::CallCanceled {
1815                    room_id: room_id.to_proto(),
1816                },
1817            )
1818            .trace_err();
1819    }
1820    response.send(proto::Ack {})?;
1821
1822    update_user_contacts(called_user_id, &session).await?;
1823    Ok(())
1824}
1825
1826/// Decline an incoming call.
1827async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1828    let room_id = RoomId::from_proto(message.room_id);
1829    {
1830        let room = session
1831            .db()
1832            .await
1833            .decline_call(Some(room_id), session.user_id())
1834            .await?
1835            .context("declining call")?;
1836        room_updated(&room, &session.peer);
1837    }
1838
1839    for connection_id in session
1840        .connection_pool()
1841        .await
1842        .user_connection_ids(session.user_id())
1843    {
1844        session
1845            .peer
1846            .send(
1847                connection_id,
1848                proto::CallCanceled {
1849                    room_id: room_id.to_proto(),
1850                },
1851            )
1852            .trace_err();
1853    }
1854    update_user_contacts(session.user_id(), &session).await?;
1855    Ok(())
1856}
1857
1858/// Updates other participants in the room with your current location.
1859async fn update_participant_location(
1860    request: proto::UpdateParticipantLocation,
1861    response: Response<proto::UpdateParticipantLocation>,
1862    session: MessageContext,
1863) -> Result<()> {
1864    let room_id = RoomId::from_proto(request.room_id);
1865    let location = request.location.context("invalid location")?;
1866
1867    let db = session.db().await;
1868    let room = db
1869        .update_room_participant_location(room_id, session.connection_id, location)
1870        .await?;
1871
1872    room_updated(&room, &session.peer);
1873    response.send(proto::Ack {})?;
1874    Ok(())
1875}
1876
1877/// Share a project into the room.
1878async fn share_project(
1879    request: proto::ShareProject,
1880    response: Response<proto::ShareProject>,
1881    session: MessageContext,
1882) -> Result<()> {
1883    let (project_id, room) = &*session
1884        .db()
1885        .await
1886        .share_project(
1887            RoomId::from_proto(request.room_id),
1888            session.connection_id,
1889            &request.worktrees,
1890            request.is_ssh_project,
1891        )
1892        .await?;
1893    response.send(proto::ShareProjectResponse {
1894        project_id: project_id.to_proto(),
1895    })?;
1896    room_updated(room, &session.peer);
1897
1898    Ok(())
1899}
1900
1901/// Unshare a project from the room.
1902async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1903    let project_id = ProjectId::from_proto(message.project_id);
1904    unshare_project_internal(project_id, session.connection_id, &session).await
1905}
1906
1907async fn unshare_project_internal(
1908    project_id: ProjectId,
1909    connection_id: ConnectionId,
1910    session: &Session,
1911) -> Result<()> {
1912    let delete = {
1913        let room_guard = session
1914            .db()
1915            .await
1916            .unshare_project(project_id, connection_id)
1917            .await?;
1918
1919        let (delete, room, guest_connection_ids) = &*room_guard;
1920
1921        let message = proto::UnshareProject {
1922            project_id: project_id.to_proto(),
1923        };
1924
1925        broadcast(
1926            Some(connection_id),
1927            guest_connection_ids.iter().copied(),
1928            |conn_id| session.peer.send(conn_id, message.clone()),
1929        );
1930        if let Some(room) = room {
1931            room_updated(room, &session.peer);
1932        }
1933
1934        *delete
1935    };
1936
1937    if delete {
1938        let db = session.db().await;
1939        db.delete_project(project_id).await?;
1940    }
1941
1942    Ok(())
1943}
1944
1945/// Join someone elses shared project.
1946async fn join_project(
1947    request: proto::JoinProject,
1948    response: Response<proto::JoinProject>,
1949    session: MessageContext,
1950) -> Result<()> {
1951    let project_id = ProjectId::from_proto(request.project_id);
1952
1953    tracing::info!(%project_id, "join project");
1954
1955    let db = session.db().await;
1956    let (project, replica_id) = &mut *db
1957        .join_project(
1958            project_id,
1959            session.connection_id,
1960            session.user_id(),
1961            request.committer_name.clone(),
1962            request.committer_email.clone(),
1963        )
1964        .await?;
1965    drop(db);
1966    tracing::info!(%project_id, "join remote project");
1967    let collaborators = project
1968        .collaborators
1969        .iter()
1970        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1971        .map(|collaborator| collaborator.to_proto())
1972        .collect::<Vec<_>>();
1973    let project_id = project.id;
1974    let guest_user_id = session.user_id();
1975
1976    let worktrees = project
1977        .worktrees
1978        .iter()
1979        .map(|(id, worktree)| proto::WorktreeMetadata {
1980            id: *id,
1981            root_name: worktree.root_name.clone(),
1982            visible: worktree.visible,
1983            abs_path: worktree.abs_path.clone(),
1984        })
1985        .collect::<Vec<_>>();
1986
1987    let add_project_collaborator = proto::AddProjectCollaborator {
1988        project_id: project_id.to_proto(),
1989        collaborator: Some(proto::Collaborator {
1990            peer_id: Some(session.connection_id.into()),
1991            replica_id: replica_id.0 as u32,
1992            user_id: guest_user_id.to_proto(),
1993            is_host: false,
1994            committer_name: request.committer_name.clone(),
1995            committer_email: request.committer_email.clone(),
1996        }),
1997    };
1998
1999    for collaborator in &collaborators {
2000        session
2001            .peer
2002            .send(
2003                collaborator.peer_id.unwrap().into(),
2004                add_project_collaborator.clone(),
2005            )
2006            .trace_err();
2007    }
2008
2009    // First, we send the metadata associated with each worktree.
2010    let (language_servers, language_server_capabilities) = project
2011        .language_servers
2012        .clone()
2013        .into_iter()
2014        .map(|server| (server.server, server.capabilities))
2015        .unzip();
2016    response.send(proto::JoinProjectResponse {
2017        project_id: project.id.0 as u64,
2018        worktrees: worktrees.clone(),
2019        replica_id: replica_id.0 as u32,
2020        collaborators: collaborators.clone(),
2021        language_servers,
2022        language_server_capabilities,
2023        role: project.role.into(),
2024    })?;
2025
2026    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2027        // Stream this worktree's entries.
2028        let message = proto::UpdateWorktree {
2029            project_id: project_id.to_proto(),
2030            worktree_id,
2031            abs_path: worktree.abs_path.clone(),
2032            root_name: worktree.root_name,
2033            updated_entries: worktree.entries,
2034            removed_entries: Default::default(),
2035            scan_id: worktree.scan_id,
2036            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2037            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2038            removed_repositories: Default::default(),
2039        };
2040        for update in proto::split_worktree_update(message) {
2041            session.peer.send(session.connection_id, update.clone())?;
2042        }
2043
2044        // Stream this worktree's diagnostics.
2045        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
2046        if let Some(summary) = worktree_diagnostics.next() {
2047            let message = proto::UpdateDiagnosticSummary {
2048                project_id: project.id.to_proto(),
2049                worktree_id: worktree.id,
2050                summary: Some(summary),
2051                more_summaries: worktree_diagnostics.collect(),
2052            };
2053            session.peer.send(session.connection_id, message)?;
2054        }
2055
2056        for settings_file in worktree.settings_files {
2057            session.peer.send(
2058                session.connection_id,
2059                proto::UpdateWorktreeSettings {
2060                    project_id: project_id.to_proto(),
2061                    worktree_id: worktree.id,
2062                    path: settings_file.path,
2063                    content: Some(settings_file.content),
2064                    kind: Some(settings_file.kind.to_proto() as i32),
2065                },
2066            )?;
2067        }
2068    }
2069
2070    for repository in mem::take(&mut project.repositories) {
2071        for update in split_repository_update(repository) {
2072            session.peer.send(session.connection_id, update)?;
2073        }
2074    }
2075
2076    for language_server in &project.language_servers {
2077        session.peer.send(
2078            session.connection_id,
2079            proto::UpdateLanguageServer {
2080                project_id: project_id.to_proto(),
2081                server_name: Some(language_server.server.name.clone()),
2082                language_server_id: language_server.server.id,
2083                variant: Some(
2084                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2085                        proto::LspDiskBasedDiagnosticsUpdated {},
2086                    ),
2087                ),
2088            },
2089        )?;
2090    }
2091
2092    Ok(())
2093}
2094
2095/// Leave someone elses shared project.
2096async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2097    let sender_id = session.connection_id;
2098    let project_id = ProjectId::from_proto(request.project_id);
2099    let db = session.db().await;
2100
2101    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2102    tracing::info!(
2103        %project_id,
2104        "leave project"
2105    );
2106
2107    project_left(project, &session);
2108    if let Some(room) = room {
2109        room_updated(room, &session.peer);
2110    }
2111
2112    Ok(())
2113}
2114
2115/// Updates other participants with changes to the project
2116async fn update_project(
2117    request: proto::UpdateProject,
2118    response: Response<proto::UpdateProject>,
2119    session: MessageContext,
2120) -> Result<()> {
2121    let project_id = ProjectId::from_proto(request.project_id);
2122    let (room, guest_connection_ids) = &*session
2123        .db()
2124        .await
2125        .update_project(project_id, session.connection_id, &request.worktrees)
2126        .await?;
2127    broadcast(
2128        Some(session.connection_id),
2129        guest_connection_ids.iter().copied(),
2130        |connection_id| {
2131            session
2132                .peer
2133                .forward_send(session.connection_id, connection_id, request.clone())
2134        },
2135    );
2136    if let Some(room) = room {
2137        room_updated(room, &session.peer);
2138    }
2139    response.send(proto::Ack {})?;
2140
2141    Ok(())
2142}
2143
2144/// Updates other participants with changes to the worktree
2145async fn update_worktree(
2146    request: proto::UpdateWorktree,
2147    response: Response<proto::UpdateWorktree>,
2148    session: MessageContext,
2149) -> Result<()> {
2150    let guest_connection_ids = session
2151        .db()
2152        .await
2153        .update_worktree(&request, session.connection_id)
2154        .await?;
2155
2156    broadcast(
2157        Some(session.connection_id),
2158        guest_connection_ids.iter().copied(),
2159        |connection_id| {
2160            session
2161                .peer
2162                .forward_send(session.connection_id, connection_id, request.clone())
2163        },
2164    );
2165    response.send(proto::Ack {})?;
2166    Ok(())
2167}
2168
2169async fn update_repository(
2170    request: proto::UpdateRepository,
2171    response: Response<proto::UpdateRepository>,
2172    session: MessageContext,
2173) -> Result<()> {
2174    let guest_connection_ids = session
2175        .db()
2176        .await
2177        .update_repository(&request, session.connection_id)
2178        .await?;
2179
2180    broadcast(
2181        Some(session.connection_id),
2182        guest_connection_ids.iter().copied(),
2183        |connection_id| {
2184            session
2185                .peer
2186                .forward_send(session.connection_id, connection_id, request.clone())
2187        },
2188    );
2189    response.send(proto::Ack {})?;
2190    Ok(())
2191}
2192
2193async fn remove_repository(
2194    request: proto::RemoveRepository,
2195    response: Response<proto::RemoveRepository>,
2196    session: MessageContext,
2197) -> Result<()> {
2198    let guest_connection_ids = session
2199        .db()
2200        .await
2201        .remove_repository(&request, session.connection_id)
2202        .await?;
2203
2204    broadcast(
2205        Some(session.connection_id),
2206        guest_connection_ids.iter().copied(),
2207        |connection_id| {
2208            session
2209                .peer
2210                .forward_send(session.connection_id, connection_id, request.clone())
2211        },
2212    );
2213    response.send(proto::Ack {})?;
2214    Ok(())
2215}
2216
2217/// Updates other participants with changes to the diagnostics
2218async fn update_diagnostic_summary(
2219    message: proto::UpdateDiagnosticSummary,
2220    session: MessageContext,
2221) -> Result<()> {
2222    let guest_connection_ids = session
2223        .db()
2224        .await
2225        .update_diagnostic_summary(&message, session.connection_id)
2226        .await?;
2227
2228    broadcast(
2229        Some(session.connection_id),
2230        guest_connection_ids.iter().copied(),
2231        |connection_id| {
2232            session
2233                .peer
2234                .forward_send(session.connection_id, connection_id, message.clone())
2235        },
2236    );
2237
2238    Ok(())
2239}
2240
2241/// Updates other participants with changes to the worktree settings
2242async fn update_worktree_settings(
2243    message: proto::UpdateWorktreeSettings,
2244    session: MessageContext,
2245) -> Result<()> {
2246    let guest_connection_ids = session
2247        .db()
2248        .await
2249        .update_worktree_settings(&message, session.connection_id)
2250        .await?;
2251
2252    broadcast(
2253        Some(session.connection_id),
2254        guest_connection_ids.iter().copied(),
2255        |connection_id| {
2256            session
2257                .peer
2258                .forward_send(session.connection_id, connection_id, message.clone())
2259        },
2260    );
2261
2262    Ok(())
2263}
2264
2265/// Notify other participants that a language server has started.
2266async fn start_language_server(
2267    request: proto::StartLanguageServer,
2268    session: MessageContext,
2269) -> Result<()> {
2270    let guest_connection_ids = session
2271        .db()
2272        .await
2273        .start_language_server(&request, session.connection_id)
2274        .await?;
2275
2276    broadcast(
2277        Some(session.connection_id),
2278        guest_connection_ids.iter().copied(),
2279        |connection_id| {
2280            session
2281                .peer
2282                .forward_send(session.connection_id, connection_id, request.clone())
2283        },
2284    );
2285    Ok(())
2286}
2287
2288/// Notify other participants that a language server has changed.
2289async fn update_language_server(
2290    request: proto::UpdateLanguageServer,
2291    session: MessageContext,
2292) -> Result<()> {
2293    let project_id = ProjectId::from_proto(request.project_id);
2294    let db = session.db().await;
2295
2296    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2297    {
2298        if let Some(capabilities) = update.capabilities.clone() {
2299            db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2300                .await?;
2301        }
2302    }
2303
2304    let project_connection_ids = db
2305        .project_connection_ids(project_id, session.connection_id, true)
2306        .await?;
2307    broadcast(
2308        Some(session.connection_id),
2309        project_connection_ids.iter().copied(),
2310        |connection_id| {
2311            session
2312                .peer
2313                .forward_send(session.connection_id, connection_id, request.clone())
2314        },
2315    );
2316    Ok(())
2317}
2318
2319/// forward a project request to the host. These requests should be read only
2320/// as guests are allowed to send them.
2321async fn forward_read_only_project_request<T>(
2322    request: T,
2323    response: Response<T>,
2324    session: MessageContext,
2325) -> Result<()>
2326where
2327    T: EntityMessage + RequestMessage,
2328{
2329    let project_id = ProjectId::from_proto(request.remote_entity_id());
2330    let host_connection_id = session
2331        .db()
2332        .await
2333        .host_for_read_only_project_request(project_id, session.connection_id)
2334        .await?;
2335    let payload = session.forward_request(host_connection_id, request).await?;
2336    response.send(payload)?;
2337    Ok(())
2338}
2339
2340/// forward a project request to the host. These requests are disallowed
2341/// for guests.
2342async fn forward_mutating_project_request<T>(
2343    request: T,
2344    response: Response<T>,
2345    session: MessageContext,
2346) -> Result<()>
2347where
2348    T: EntityMessage + RequestMessage,
2349{
2350    let project_id = ProjectId::from_proto(request.remote_entity_id());
2351
2352    let host_connection_id = session
2353        .db()
2354        .await
2355        .host_for_mutating_project_request(project_id, session.connection_id)
2356        .await?;
2357    let payload = session.forward_request(host_connection_id, request).await?;
2358    response.send(payload)?;
2359    Ok(())
2360}
2361
2362async fn multi_lsp_query(
2363    request: MultiLspQuery,
2364    response: Response<MultiLspQuery>,
2365    session: MessageContext,
2366) -> Result<()> {
2367    tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2368    tracing::info!("multi_lsp_query message received");
2369    forward_mutating_project_request(request, response, session).await
2370}
2371
2372/// Notify other participants that a new buffer has been created
2373async fn create_buffer_for_peer(
2374    request: proto::CreateBufferForPeer,
2375    session: MessageContext,
2376) -> Result<()> {
2377    session
2378        .db()
2379        .await
2380        .check_user_is_project_host(
2381            ProjectId::from_proto(request.project_id),
2382            session.connection_id,
2383        )
2384        .await?;
2385    let peer_id = request.peer_id.context("invalid peer id")?;
2386    session
2387        .peer
2388        .forward_send(session.connection_id, peer_id.into(), request)?;
2389    Ok(())
2390}
2391
2392/// Notify other participants that a buffer has been updated. This is
2393/// allowed for guests as long as the update is limited to selections.
2394async fn update_buffer(
2395    request: proto::UpdateBuffer,
2396    response: Response<proto::UpdateBuffer>,
2397    session: MessageContext,
2398) -> Result<()> {
2399    let project_id = ProjectId::from_proto(request.project_id);
2400    let mut capability = Capability::ReadOnly;
2401
2402    for op in request.operations.iter() {
2403        match op.variant {
2404            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2405            Some(_) => capability = Capability::ReadWrite,
2406        }
2407    }
2408
2409    let host = {
2410        let guard = session
2411            .db()
2412            .await
2413            .connections_for_buffer_update(project_id, session.connection_id, capability)
2414            .await?;
2415
2416        let (host, guests) = &*guard;
2417
2418        broadcast(
2419            Some(session.connection_id),
2420            guests.clone(),
2421            |connection_id| {
2422                session
2423                    .peer
2424                    .forward_send(session.connection_id, connection_id, request.clone())
2425            },
2426        );
2427
2428        *host
2429    };
2430
2431    if host != session.connection_id {
2432        session.forward_request(host, request.clone()).await?;
2433    }
2434
2435    response.send(proto::Ack {})?;
2436    Ok(())
2437}
2438
2439async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2440    let project_id = ProjectId::from_proto(message.project_id);
2441
2442    let operation = message.operation.as_ref().context("invalid operation")?;
2443    let capability = match operation.variant.as_ref() {
2444        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2445            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2446                match buffer_op.variant {
2447                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2448                        Capability::ReadOnly
2449                    }
2450                    _ => Capability::ReadWrite,
2451                }
2452            } else {
2453                Capability::ReadWrite
2454            }
2455        }
2456        Some(_) => Capability::ReadWrite,
2457        None => Capability::ReadOnly,
2458    };
2459
2460    let guard = session
2461        .db()
2462        .await
2463        .connections_for_buffer_update(project_id, session.connection_id, capability)
2464        .await?;
2465
2466    let (host, guests) = &*guard;
2467
2468    broadcast(
2469        Some(session.connection_id),
2470        guests.iter().chain([host]).copied(),
2471        |connection_id| {
2472            session
2473                .peer
2474                .forward_send(session.connection_id, connection_id, message.clone())
2475        },
2476    );
2477
2478    Ok(())
2479}
2480
2481/// Notify other participants that a project has been updated.
2482async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2483    request: T,
2484    session: MessageContext,
2485) -> Result<()> {
2486    let project_id = ProjectId::from_proto(request.remote_entity_id());
2487    let project_connection_ids = session
2488        .db()
2489        .await
2490        .project_connection_ids(project_id, session.connection_id, false)
2491        .await?;
2492
2493    broadcast(
2494        Some(session.connection_id),
2495        project_connection_ids.iter().copied(),
2496        |connection_id| {
2497            session
2498                .peer
2499                .forward_send(session.connection_id, connection_id, request.clone())
2500        },
2501    );
2502    Ok(())
2503}
2504
2505/// Start following another user in a call.
2506async fn follow(
2507    request: proto::Follow,
2508    response: Response<proto::Follow>,
2509    session: MessageContext,
2510) -> Result<()> {
2511    let room_id = RoomId::from_proto(request.room_id);
2512    let project_id = request.project_id.map(ProjectId::from_proto);
2513    let leader_id = request.leader_id.context("invalid leader id")?.into();
2514    let follower_id = session.connection_id;
2515
2516    session
2517        .db()
2518        .await
2519        .check_room_participants(room_id, leader_id, session.connection_id)
2520        .await?;
2521
2522    let response_payload = session.forward_request(leader_id, request).await?;
2523    response.send(response_payload)?;
2524
2525    if let Some(project_id) = project_id {
2526        let room = session
2527            .db()
2528            .await
2529            .follow(room_id, project_id, leader_id, follower_id)
2530            .await?;
2531        room_updated(&room, &session.peer);
2532    }
2533
2534    Ok(())
2535}
2536
2537/// Stop following another user in a call.
2538async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2539    let room_id = RoomId::from_proto(request.room_id);
2540    let project_id = request.project_id.map(ProjectId::from_proto);
2541    let leader_id = request.leader_id.context("invalid leader id")?.into();
2542    let follower_id = session.connection_id;
2543
2544    session
2545        .db()
2546        .await
2547        .check_room_participants(room_id, leader_id, session.connection_id)
2548        .await?;
2549
2550    session
2551        .peer
2552        .forward_send(session.connection_id, leader_id, request)?;
2553
2554    if let Some(project_id) = project_id {
2555        let room = session
2556            .db()
2557            .await
2558            .unfollow(room_id, project_id, leader_id, follower_id)
2559            .await?;
2560        room_updated(&room, &session.peer);
2561    }
2562
2563    Ok(())
2564}
2565
2566/// Notify everyone following you of your current location.
2567async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2568    let room_id = RoomId::from_proto(request.room_id);
2569    let database = session.db.lock().await;
2570
2571    let connection_ids = if let Some(project_id) = request.project_id {
2572        let project_id = ProjectId::from_proto(project_id);
2573        database
2574            .project_connection_ids(project_id, session.connection_id, true)
2575            .await?
2576    } else {
2577        database
2578            .room_connection_ids(room_id, session.connection_id)
2579            .await?
2580    };
2581
2582    // For now, don't send view update messages back to that view's current leader.
2583    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2584        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2585        _ => None,
2586    });
2587
2588    for connection_id in connection_ids.iter().cloned() {
2589        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2590            session
2591                .peer
2592                .forward_send(session.connection_id, connection_id, request.clone())?;
2593        }
2594    }
2595    Ok(())
2596}
2597
2598/// Get public data about users.
2599async fn get_users(
2600    request: proto::GetUsers,
2601    response: Response<proto::GetUsers>,
2602    session: MessageContext,
2603) -> Result<()> {
2604    let user_ids = request
2605        .user_ids
2606        .into_iter()
2607        .map(UserId::from_proto)
2608        .collect();
2609    let users = session
2610        .db()
2611        .await
2612        .get_users_by_ids(user_ids)
2613        .await?
2614        .into_iter()
2615        .map(|user| proto::User {
2616            id: user.id.to_proto(),
2617            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2618            github_login: user.github_login,
2619            name: user.name,
2620        })
2621        .collect();
2622    response.send(proto::UsersResponse { users })?;
2623    Ok(())
2624}
2625
2626/// Search for users (to invite) buy Github login
2627async fn fuzzy_search_users(
2628    request: proto::FuzzySearchUsers,
2629    response: Response<proto::FuzzySearchUsers>,
2630    session: MessageContext,
2631) -> Result<()> {
2632    let query = request.query;
2633    let users = match query.len() {
2634        0 => vec![],
2635        1 | 2 => session
2636            .db()
2637            .await
2638            .get_user_by_github_login(&query)
2639            .await?
2640            .into_iter()
2641            .collect(),
2642        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2643    };
2644    let users = users
2645        .into_iter()
2646        .filter(|user| user.id != session.user_id())
2647        .map(|user| proto::User {
2648            id: user.id.to_proto(),
2649            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2650            github_login: user.github_login,
2651            name: user.name,
2652        })
2653        .collect();
2654    response.send(proto::UsersResponse { users })?;
2655    Ok(())
2656}
2657
2658/// Send a contact request to another user.
2659async fn request_contact(
2660    request: proto::RequestContact,
2661    response: Response<proto::RequestContact>,
2662    session: MessageContext,
2663) -> Result<()> {
2664    let requester_id = session.user_id();
2665    let responder_id = UserId::from_proto(request.responder_id);
2666    if requester_id == responder_id {
2667        return Err(anyhow!("cannot add yourself as a contact"))?;
2668    }
2669
2670    let notifications = session
2671        .db()
2672        .await
2673        .send_contact_request(requester_id, responder_id)
2674        .await?;
2675
2676    // Update outgoing contact requests of requester
2677    let mut update = proto::UpdateContacts::default();
2678    update.outgoing_requests.push(responder_id.to_proto());
2679    for connection_id in session
2680        .connection_pool()
2681        .await
2682        .user_connection_ids(requester_id)
2683    {
2684        session.peer.send(connection_id, update.clone())?;
2685    }
2686
2687    // Update incoming contact requests of responder
2688    let mut update = proto::UpdateContacts::default();
2689    update
2690        .incoming_requests
2691        .push(proto::IncomingContactRequest {
2692            requester_id: requester_id.to_proto(),
2693        });
2694    let connection_pool = session.connection_pool().await;
2695    for connection_id in connection_pool.user_connection_ids(responder_id) {
2696        session.peer.send(connection_id, update.clone())?;
2697    }
2698
2699    send_notifications(&connection_pool, &session.peer, notifications);
2700
2701    response.send(proto::Ack {})?;
2702    Ok(())
2703}
2704
2705/// Accept or decline a contact request
2706async fn respond_to_contact_request(
2707    request: proto::RespondToContactRequest,
2708    response: Response<proto::RespondToContactRequest>,
2709    session: MessageContext,
2710) -> Result<()> {
2711    let responder_id = session.user_id();
2712    let requester_id = UserId::from_proto(request.requester_id);
2713    let db = session.db().await;
2714    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2715        db.dismiss_contact_notification(responder_id, requester_id)
2716            .await?;
2717    } else {
2718        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2719
2720        let notifications = db
2721            .respond_to_contact_request(responder_id, requester_id, accept)
2722            .await?;
2723        let requester_busy = db.is_user_busy(requester_id).await?;
2724        let responder_busy = db.is_user_busy(responder_id).await?;
2725
2726        let pool = session.connection_pool().await;
2727        // Update responder with new contact
2728        let mut update = proto::UpdateContacts::default();
2729        if accept {
2730            update
2731                .contacts
2732                .push(contact_for_user(requester_id, requester_busy, &pool));
2733        }
2734        update
2735            .remove_incoming_requests
2736            .push(requester_id.to_proto());
2737        for connection_id in pool.user_connection_ids(responder_id) {
2738            session.peer.send(connection_id, update.clone())?;
2739        }
2740
2741        // Update requester with new contact
2742        let mut update = proto::UpdateContacts::default();
2743        if accept {
2744            update
2745                .contacts
2746                .push(contact_for_user(responder_id, responder_busy, &pool));
2747        }
2748        update
2749            .remove_outgoing_requests
2750            .push(responder_id.to_proto());
2751
2752        for connection_id in pool.user_connection_ids(requester_id) {
2753            session.peer.send(connection_id, update.clone())?;
2754        }
2755
2756        send_notifications(&pool, &session.peer, notifications);
2757    }
2758
2759    response.send(proto::Ack {})?;
2760    Ok(())
2761}
2762
2763/// Remove a contact.
2764async fn remove_contact(
2765    request: proto::RemoveContact,
2766    response: Response<proto::RemoveContact>,
2767    session: MessageContext,
2768) -> Result<()> {
2769    let requester_id = session.user_id();
2770    let responder_id = UserId::from_proto(request.user_id);
2771    let db = session.db().await;
2772    let (contact_accepted, deleted_notification_id) =
2773        db.remove_contact(requester_id, responder_id).await?;
2774
2775    let pool = session.connection_pool().await;
2776    // Update outgoing contact requests of requester
2777    let mut update = proto::UpdateContacts::default();
2778    if contact_accepted {
2779        update.remove_contacts.push(responder_id.to_proto());
2780    } else {
2781        update
2782            .remove_outgoing_requests
2783            .push(responder_id.to_proto());
2784    }
2785    for connection_id in pool.user_connection_ids(requester_id) {
2786        session.peer.send(connection_id, update.clone())?;
2787    }
2788
2789    // Update incoming contact requests of responder
2790    let mut update = proto::UpdateContacts::default();
2791    if contact_accepted {
2792        update.remove_contacts.push(requester_id.to_proto());
2793    } else {
2794        update
2795            .remove_incoming_requests
2796            .push(requester_id.to_proto());
2797    }
2798    for connection_id in pool.user_connection_ids(responder_id) {
2799        session.peer.send(connection_id, update.clone())?;
2800        if let Some(notification_id) = deleted_notification_id {
2801            session.peer.send(
2802                connection_id,
2803                proto::DeleteNotification {
2804                    notification_id: notification_id.to_proto(),
2805                },
2806            )?;
2807        }
2808    }
2809
2810    response.send(proto::Ack {})?;
2811    Ok(())
2812}
2813
2814fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2815    version.0.minor() < 139
2816}
2817
2818async fn subscribe_to_channels(
2819    _: proto::SubscribeToChannels,
2820    session: MessageContext,
2821) -> Result<()> {
2822    subscribe_user_to_channels(session.user_id(), &session).await?;
2823    Ok(())
2824}
2825
2826async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2827    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2828    let mut pool = session.connection_pool().await;
2829    for membership in &channels_for_user.channel_memberships {
2830        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2831    }
2832    session.peer.send(
2833        session.connection_id,
2834        build_update_user_channels(&channels_for_user),
2835    )?;
2836    session.peer.send(
2837        session.connection_id,
2838        build_channels_update(channels_for_user),
2839    )?;
2840    Ok(())
2841}
2842
2843/// Creates a new channel.
2844async fn create_channel(
2845    request: proto::CreateChannel,
2846    response: Response<proto::CreateChannel>,
2847    session: MessageContext,
2848) -> Result<()> {
2849    let db = session.db().await;
2850
2851    let parent_id = request.parent_id.map(ChannelId::from_proto);
2852    let (channel, membership) = db
2853        .create_channel(&request.name, parent_id, session.user_id())
2854        .await?;
2855
2856    let root_id = channel.root_id();
2857    let channel = Channel::from_model(channel);
2858
2859    response.send(proto::CreateChannelResponse {
2860        channel: Some(channel.to_proto()),
2861        parent_id: request.parent_id,
2862    })?;
2863
2864    let mut connection_pool = session.connection_pool().await;
2865    if let Some(membership) = membership {
2866        connection_pool.subscribe_to_channel(
2867            membership.user_id,
2868            membership.channel_id,
2869            membership.role,
2870        );
2871        let update = proto::UpdateUserChannels {
2872            channel_memberships: vec![proto::ChannelMembership {
2873                channel_id: membership.channel_id.to_proto(),
2874                role: membership.role.into(),
2875            }],
2876            ..Default::default()
2877        };
2878        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2879            session.peer.send(connection_id, update.clone())?;
2880        }
2881    }
2882
2883    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2884        if !role.can_see_channel(channel.visibility) {
2885            continue;
2886        }
2887
2888        let update = proto::UpdateChannels {
2889            channels: vec![channel.to_proto()],
2890            ..Default::default()
2891        };
2892        session.peer.send(connection_id, update.clone())?;
2893    }
2894
2895    Ok(())
2896}
2897
2898/// Delete a channel
2899async fn delete_channel(
2900    request: proto::DeleteChannel,
2901    response: Response<proto::DeleteChannel>,
2902    session: MessageContext,
2903) -> Result<()> {
2904    let db = session.db().await;
2905
2906    let channel_id = request.channel_id;
2907    let (root_channel, removed_channels) = db
2908        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2909        .await?;
2910    response.send(proto::Ack {})?;
2911
2912    // Notify members of removed channels
2913    let mut update = proto::UpdateChannels::default();
2914    update
2915        .delete_channels
2916        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2917
2918    let connection_pool = session.connection_pool().await;
2919    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2920        session.peer.send(connection_id, update.clone())?;
2921    }
2922
2923    Ok(())
2924}
2925
2926/// Invite someone to join a channel.
2927async fn invite_channel_member(
2928    request: proto::InviteChannelMember,
2929    response: Response<proto::InviteChannelMember>,
2930    session: MessageContext,
2931) -> Result<()> {
2932    let db = session.db().await;
2933    let channel_id = ChannelId::from_proto(request.channel_id);
2934    let invitee_id = UserId::from_proto(request.user_id);
2935    let InviteMemberResult {
2936        channel,
2937        notifications,
2938    } = db
2939        .invite_channel_member(
2940            channel_id,
2941            invitee_id,
2942            session.user_id(),
2943            request.role().into(),
2944        )
2945        .await?;
2946
2947    let update = proto::UpdateChannels {
2948        channel_invitations: vec![channel.to_proto()],
2949        ..Default::default()
2950    };
2951
2952    let connection_pool = session.connection_pool().await;
2953    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2954        session.peer.send(connection_id, update.clone())?;
2955    }
2956
2957    send_notifications(&connection_pool, &session.peer, notifications);
2958
2959    response.send(proto::Ack {})?;
2960    Ok(())
2961}
2962
2963/// remove someone from a channel
2964async fn remove_channel_member(
2965    request: proto::RemoveChannelMember,
2966    response: Response<proto::RemoveChannelMember>,
2967    session: MessageContext,
2968) -> Result<()> {
2969    let db = session.db().await;
2970    let channel_id = ChannelId::from_proto(request.channel_id);
2971    let member_id = UserId::from_proto(request.user_id);
2972
2973    let RemoveChannelMemberResult {
2974        membership_update,
2975        notification_id,
2976    } = db
2977        .remove_channel_member(channel_id, member_id, session.user_id())
2978        .await?;
2979
2980    let mut connection_pool = session.connection_pool().await;
2981    notify_membership_updated(
2982        &mut connection_pool,
2983        membership_update,
2984        member_id,
2985        &session.peer,
2986    );
2987    for connection_id in connection_pool.user_connection_ids(member_id) {
2988        if let Some(notification_id) = notification_id {
2989            session
2990                .peer
2991                .send(
2992                    connection_id,
2993                    proto::DeleteNotification {
2994                        notification_id: notification_id.to_proto(),
2995                    },
2996                )
2997                .trace_err();
2998        }
2999    }
3000
3001    response.send(proto::Ack {})?;
3002    Ok(())
3003}
3004
3005/// Toggle the channel between public and private.
3006/// Care is taken to maintain the invariant that public channels only descend from public channels,
3007/// (though members-only channels can appear at any point in the hierarchy).
3008async fn set_channel_visibility(
3009    request: proto::SetChannelVisibility,
3010    response: Response<proto::SetChannelVisibility>,
3011    session: MessageContext,
3012) -> Result<()> {
3013    let db = session.db().await;
3014    let channel_id = ChannelId::from_proto(request.channel_id);
3015    let visibility = request.visibility().into();
3016
3017    let channel_model = db
3018        .set_channel_visibility(channel_id, visibility, session.user_id())
3019        .await?;
3020    let root_id = channel_model.root_id();
3021    let channel = Channel::from_model(channel_model);
3022
3023    let mut connection_pool = session.connection_pool().await;
3024    for (user_id, role) in connection_pool
3025        .channel_user_ids(root_id)
3026        .collect::<Vec<_>>()
3027        .into_iter()
3028    {
3029        let update = if role.can_see_channel(channel.visibility) {
3030            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3031            proto::UpdateChannels {
3032                channels: vec![channel.to_proto()],
3033                ..Default::default()
3034            }
3035        } else {
3036            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3037            proto::UpdateChannels {
3038                delete_channels: vec![channel.id.to_proto()],
3039                ..Default::default()
3040            }
3041        };
3042
3043        for connection_id in connection_pool.user_connection_ids(user_id) {
3044            session.peer.send(connection_id, update.clone())?;
3045        }
3046    }
3047
3048    response.send(proto::Ack {})?;
3049    Ok(())
3050}
3051
3052/// Alter the role for a user in the channel.
3053async fn set_channel_member_role(
3054    request: proto::SetChannelMemberRole,
3055    response: Response<proto::SetChannelMemberRole>,
3056    session: MessageContext,
3057) -> Result<()> {
3058    let db = session.db().await;
3059    let channel_id = ChannelId::from_proto(request.channel_id);
3060    let member_id = UserId::from_proto(request.user_id);
3061    let result = db
3062        .set_channel_member_role(
3063            channel_id,
3064            session.user_id(),
3065            member_id,
3066            request.role().into(),
3067        )
3068        .await?;
3069
3070    match result {
3071        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3072            let mut connection_pool = session.connection_pool().await;
3073            notify_membership_updated(
3074                &mut connection_pool,
3075                membership_update,
3076                member_id,
3077                &session.peer,
3078            )
3079        }
3080        db::SetMemberRoleResult::InviteUpdated(channel) => {
3081            let update = proto::UpdateChannels {
3082                channel_invitations: vec![channel.to_proto()],
3083                ..Default::default()
3084            };
3085
3086            for connection_id in session
3087                .connection_pool()
3088                .await
3089                .user_connection_ids(member_id)
3090            {
3091                session.peer.send(connection_id, update.clone())?;
3092            }
3093        }
3094    }
3095
3096    response.send(proto::Ack {})?;
3097    Ok(())
3098}
3099
3100/// Change the name of a channel
3101async fn rename_channel(
3102    request: proto::RenameChannel,
3103    response: Response<proto::RenameChannel>,
3104    session: MessageContext,
3105) -> Result<()> {
3106    let db = session.db().await;
3107    let channel_id = ChannelId::from_proto(request.channel_id);
3108    let channel_model = db
3109        .rename_channel(channel_id, session.user_id(), &request.name)
3110        .await?;
3111    let root_id = channel_model.root_id();
3112    let channel = Channel::from_model(channel_model);
3113
3114    response.send(proto::RenameChannelResponse {
3115        channel: Some(channel.to_proto()),
3116    })?;
3117
3118    let connection_pool = session.connection_pool().await;
3119    let update = proto::UpdateChannels {
3120        channels: vec![channel.to_proto()],
3121        ..Default::default()
3122    };
3123    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3124        if role.can_see_channel(channel.visibility) {
3125            session.peer.send(connection_id, update.clone())?;
3126        }
3127    }
3128
3129    Ok(())
3130}
3131
3132/// Move a channel to a new parent.
3133async fn move_channel(
3134    request: proto::MoveChannel,
3135    response: Response<proto::MoveChannel>,
3136    session: MessageContext,
3137) -> Result<()> {
3138    let channel_id = ChannelId::from_proto(request.channel_id);
3139    let to = ChannelId::from_proto(request.to);
3140
3141    let (root_id, channels) = session
3142        .db()
3143        .await
3144        .move_channel(channel_id, to, session.user_id())
3145        .await?;
3146
3147    let connection_pool = session.connection_pool().await;
3148    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3149        let channels = channels
3150            .iter()
3151            .filter_map(|channel| {
3152                if role.can_see_channel(channel.visibility) {
3153                    Some(channel.to_proto())
3154                } else {
3155                    None
3156                }
3157            })
3158            .collect::<Vec<_>>();
3159        if channels.is_empty() {
3160            continue;
3161        }
3162
3163        let update = proto::UpdateChannels {
3164            channels,
3165            ..Default::default()
3166        };
3167
3168        session.peer.send(connection_id, update.clone())?;
3169    }
3170
3171    response.send(Ack {})?;
3172    Ok(())
3173}
3174
3175async fn reorder_channel(
3176    request: proto::ReorderChannel,
3177    response: Response<proto::ReorderChannel>,
3178    session: MessageContext,
3179) -> Result<()> {
3180    let channel_id = ChannelId::from_proto(request.channel_id);
3181    let direction = request.direction();
3182
3183    let updated_channels = session
3184        .db()
3185        .await
3186        .reorder_channel(channel_id, direction, session.user_id())
3187        .await?;
3188
3189    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3190        let connection_pool = session.connection_pool().await;
3191        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3192            let channels = updated_channels
3193                .iter()
3194                .filter_map(|channel| {
3195                    if role.can_see_channel(channel.visibility) {
3196                        Some(channel.to_proto())
3197                    } else {
3198                        None
3199                    }
3200                })
3201                .collect::<Vec<_>>();
3202
3203            if channels.is_empty() {
3204                continue;
3205            }
3206
3207            let update = proto::UpdateChannels {
3208                channels,
3209                ..Default::default()
3210            };
3211
3212            session.peer.send(connection_id, update.clone())?;
3213        }
3214    }
3215
3216    response.send(Ack {})?;
3217    Ok(())
3218}
3219
3220/// Get the list of channel members
3221async fn get_channel_members(
3222    request: proto::GetChannelMembers,
3223    response: Response<proto::GetChannelMembers>,
3224    session: MessageContext,
3225) -> Result<()> {
3226    let db = session.db().await;
3227    let channel_id = ChannelId::from_proto(request.channel_id);
3228    let limit = if request.limit == 0 {
3229        u16::MAX as u64
3230    } else {
3231        request.limit
3232    };
3233    let (members, users) = db
3234        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3235        .await?;
3236    response.send(proto::GetChannelMembersResponse { members, users })?;
3237    Ok(())
3238}
3239
3240/// Accept or decline a channel invitation.
3241async fn respond_to_channel_invite(
3242    request: proto::RespondToChannelInvite,
3243    response: Response<proto::RespondToChannelInvite>,
3244    session: MessageContext,
3245) -> Result<()> {
3246    let db = session.db().await;
3247    let channel_id = ChannelId::from_proto(request.channel_id);
3248    let RespondToChannelInvite {
3249        membership_update,
3250        notifications,
3251    } = db
3252        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3253        .await?;
3254
3255    let mut connection_pool = session.connection_pool().await;
3256    if let Some(membership_update) = membership_update {
3257        notify_membership_updated(
3258            &mut connection_pool,
3259            membership_update,
3260            session.user_id(),
3261            &session.peer,
3262        );
3263    } else {
3264        let update = proto::UpdateChannels {
3265            remove_channel_invitations: vec![channel_id.to_proto()],
3266            ..Default::default()
3267        };
3268
3269        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3270            session.peer.send(connection_id, update.clone())?;
3271        }
3272    };
3273
3274    send_notifications(&connection_pool, &session.peer, notifications);
3275
3276    response.send(proto::Ack {})?;
3277
3278    Ok(())
3279}
3280
3281/// Join the channels' room
3282async fn join_channel(
3283    request: proto::JoinChannel,
3284    response: Response<proto::JoinChannel>,
3285    session: MessageContext,
3286) -> Result<()> {
3287    let channel_id = ChannelId::from_proto(request.channel_id);
3288    join_channel_internal(channel_id, Box::new(response), session).await
3289}
3290
/// A response handle that can carry a `proto::JoinRoomResponse`.
///
/// `Response<proto::JoinChannel>` and `Response<proto::JoinRoom>` both reply
/// with a `JoinRoomResponse`. `join_channel_internal` is written against this
/// trait so it can serve either request.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the originating request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Path-qualified call: inherent associated items take precedence over
        // trait items, so this reaches `Response`'s own `send` rather than
        // recursing into this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Path-qualified call so the inherent `Response::send` is used, not
        // this trait method.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3304
/// Join the room belonging to `channel_id` for the session's user and reply
/// through `response`.
///
/// In order, this:
/// 1. leaves any stale room connection the database still records for this
///    user (see the inline comment for why);
/// 2. joins the channel's room in the database, which also returns an optional
///    membership update and the user's `ChannelRole` in the channel;
/// 3. when a LiveKit client is configured, mints a LiveKit token for the room.
///    `ChannelRole::Guest` gets a guest token with `can_publish: false`; every
///    other role gets a room token with `can_publish: true`;
/// 4. sends the `JoinRoomResponse`, applies and pushes any membership update
///    to the user's connections, and broadcasts the updated room to its
///    participants;
/// 5. broadcasts the channel's new participant list, then pushes the user's
///    updated contact entry to their accepted contacts (`update_user_contacts`).
///
/// # Errors
/// Fails if stale-room cleanup, the database join, sending the response or the
/// contacts update fails. It also fails if the joined room has no channel;
/// that check runs after the response has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    // The database and connection-pool guards taken inside this block are
    // dropped at its end, before `channel_updated` below acquires the pool again.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` acquires the database itself, so release
            // this guard first and re-acquire it for the join below.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or when minting the
        // token fails (the error is reported via `trace_err`). The join still
        // succeeds in that case, just without LiveKit connection info.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests may only listen; every other role may publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Tell every connection that can see the channel about its new participants.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3398
3399/// Start editing the channel notes
3400async fn join_channel_buffer(
3401    request: proto::JoinChannelBuffer,
3402    response: Response<proto::JoinChannelBuffer>,
3403    session: MessageContext,
3404) -> Result<()> {
3405    let db = session.db().await;
3406    let channel_id = ChannelId::from_proto(request.channel_id);
3407
3408    let open_response = db
3409        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3410        .await?;
3411
3412    let collaborators = open_response.collaborators.clone();
3413    response.send(open_response)?;
3414
3415    let update = UpdateChannelBufferCollaborators {
3416        channel_id: channel_id.to_proto(),
3417        collaborators: collaborators.clone(),
3418    };
3419    channel_buffer_updated(
3420        session.connection_id,
3421        collaborators
3422            .iter()
3423            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3424        &update,
3425        &session.peer,
3426    );
3427
3428    Ok(())
3429}
3430
3431/// Edit the channel notes
3432async fn update_channel_buffer(
3433    request: proto::UpdateChannelBuffer,
3434    session: MessageContext,
3435) -> Result<()> {
3436    let db = session.db().await;
3437    let channel_id = ChannelId::from_proto(request.channel_id);
3438
3439    let (collaborators, epoch, version) = db
3440        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3441        .await?;
3442
3443    channel_buffer_updated(
3444        session.connection_id,
3445        collaborators.clone(),
3446        &proto::UpdateChannelBuffer {
3447            channel_id: channel_id.to_proto(),
3448            operations: request.operations,
3449        },
3450        &session.peer,
3451    );
3452
3453    let pool = &*session.connection_pool().await;
3454
3455    let non_collaborators =
3456        pool.channel_connection_ids(channel_id)
3457            .filter_map(|(connection_id, _)| {
3458                if collaborators.contains(&connection_id) {
3459                    None
3460                } else {
3461                    Some(connection_id)
3462                }
3463            });
3464
3465    broadcast(None, non_collaborators, |peer_id| {
3466        session.peer.send(
3467            peer_id,
3468            proto::UpdateChannels {
3469                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3470                    channel_id: channel_id.to_proto(),
3471                    epoch: epoch as u64,
3472                    version: version.clone(),
3473                }],
3474                ..Default::default()
3475            },
3476        )
3477    });
3478
3479    Ok(())
3480}
3481
3482/// Rejoin the channel notes after a connection blip
3483async fn rejoin_channel_buffers(
3484    request: proto::RejoinChannelBuffers,
3485    response: Response<proto::RejoinChannelBuffers>,
3486    session: MessageContext,
3487) -> Result<()> {
3488    let db = session.db().await;
3489    let buffers = db
3490        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3491        .await?;
3492
3493    for rejoined_buffer in &buffers {
3494        let collaborators_to_notify = rejoined_buffer
3495            .buffer
3496            .collaborators
3497            .iter()
3498            .filter_map(|c| Some(c.peer_id?.into()));
3499        channel_buffer_updated(
3500            session.connection_id,
3501            collaborators_to_notify,
3502            &proto::UpdateChannelBufferCollaborators {
3503                channel_id: rejoined_buffer.buffer.channel_id,
3504                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3505            },
3506            &session.peer,
3507        );
3508    }
3509
3510    response.send(proto::RejoinChannelBuffersResponse {
3511        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3512    })?;
3513
3514    Ok(())
3515}
3516
3517/// Stop editing the channel notes
3518async fn leave_channel_buffer(
3519    request: proto::LeaveChannelBuffer,
3520    response: Response<proto::LeaveChannelBuffer>,
3521    session: MessageContext,
3522) -> Result<()> {
3523    let db = session.db().await;
3524    let channel_id = ChannelId::from_proto(request.channel_id);
3525
3526    let left_buffer = db
3527        .leave_channel_buffer(channel_id, session.connection_id)
3528        .await?;
3529
3530    response.send(Ack {})?;
3531
3532    channel_buffer_updated(
3533        session.connection_id,
3534        left_buffer.connections,
3535        &proto::UpdateChannelBufferCollaborators {
3536            channel_id: channel_id.to_proto(),
3537            collaborators: left_buffer.collaborators,
3538        },
3539        &session.peer,
3540    );
3541
3542    Ok(())
3543}
3544
3545fn channel_buffer_updated<T: EnvelopedMessage>(
3546    sender_id: ConnectionId,
3547    collaborators: impl IntoIterator<Item = ConnectionId>,
3548    message: &T,
3549    peer: &Peer,
3550) {
3551    broadcast(Some(sender_id), collaborators, |peer_id| {
3552        peer.send(peer_id, message.clone())
3553    });
3554}
3555
3556fn send_notifications(
3557    connection_pool: &ConnectionPool,
3558    peer: &Peer,
3559    notifications: db::NotificationBatch,
3560) {
3561    for (user_id, notification) in notifications {
3562        for connection_id in connection_pool.user_connection_ids(user_id) {
3563            if let Err(error) = peer.send(
3564                connection_id,
3565                proto::AddNotification {
3566                    notification: Some(notification.clone()),
3567                },
3568            ) {
3569                tracing::error!(
3570                    "failed to send notification to {:?} {}",
3571                    connection_id,
3572                    error
3573                );
3574            }
3575        }
3576    }
3577}
3578
3579/// Send a message to the channel
3580async fn send_channel_message(
3581    request: proto::SendChannelMessage,
3582    response: Response<proto::SendChannelMessage>,
3583    session: MessageContext,
3584) -> Result<()> {
3585    // Validate the message body.
3586    let body = request.body.trim().to_string();
3587    if body.len() > MAX_MESSAGE_LEN {
3588        return Err(anyhow!("message is too long"))?;
3589    }
3590    if body.is_empty() {
3591        return Err(anyhow!("message can't be blank"))?;
3592    }
3593
3594    // TODO: adjust mentions if body is trimmed
3595
3596    let timestamp = OffsetDateTime::now_utc();
3597    let nonce = request.nonce.context("nonce can't be blank")?;
3598
3599    let channel_id = ChannelId::from_proto(request.channel_id);
3600    let CreatedChannelMessage {
3601        message_id,
3602        participant_connection_ids,
3603        notifications,
3604    } = session
3605        .db()
3606        .await
3607        .create_channel_message(
3608            channel_id,
3609            session.user_id(),
3610            &body,
3611            &request.mentions,
3612            timestamp,
3613            nonce.clone().into(),
3614            request.reply_to_message_id.map(MessageId::from_proto),
3615        )
3616        .await?;
3617
3618    let message = proto::ChannelMessage {
3619        sender_id: session.user_id().to_proto(),
3620        id: message_id.to_proto(),
3621        body,
3622        mentions: request.mentions,
3623        timestamp: timestamp.unix_timestamp() as u64,
3624        nonce: Some(nonce),
3625        reply_to_message_id: request.reply_to_message_id,
3626        edited_at: None,
3627    };
3628    broadcast(
3629        Some(session.connection_id),
3630        participant_connection_ids.clone(),
3631        |connection| {
3632            session.peer.send(
3633                connection,
3634                proto::ChannelMessageSent {
3635                    channel_id: channel_id.to_proto(),
3636                    message: Some(message.clone()),
3637                },
3638            )
3639        },
3640    );
3641    response.send(proto::SendChannelMessageResponse {
3642        message: Some(message),
3643    })?;
3644
3645    let pool = &*session.connection_pool().await;
3646    let non_participants =
3647        pool.channel_connection_ids(channel_id)
3648            .filter_map(|(connection_id, _)| {
3649                if participant_connection_ids.contains(&connection_id) {
3650                    None
3651                } else {
3652                    Some(connection_id)
3653                }
3654            });
3655    broadcast(None, non_participants, |peer_id| {
3656        session.peer.send(
3657            peer_id,
3658            proto::UpdateChannels {
3659                latest_channel_message_ids: vec![proto::ChannelMessageId {
3660                    channel_id: channel_id.to_proto(),
3661                    message_id: message_id.to_proto(),
3662                }],
3663                ..Default::default()
3664            },
3665        )
3666    });
3667    send_notifications(pool, &session.peer, notifications);
3668
3669    Ok(())
3670}
3671
3672/// Delete a channel message
3673async fn remove_channel_message(
3674    request: proto::RemoveChannelMessage,
3675    response: Response<proto::RemoveChannelMessage>,
3676    session: MessageContext,
3677) -> Result<()> {
3678    let channel_id = ChannelId::from_proto(request.channel_id);
3679    let message_id = MessageId::from_proto(request.message_id);
3680    let (connection_ids, existing_notification_ids) = session
3681        .db()
3682        .await
3683        .remove_channel_message(channel_id, message_id, session.user_id())
3684        .await?;
3685
3686    broadcast(
3687        Some(session.connection_id),
3688        connection_ids,
3689        move |connection| {
3690            session.peer.send(connection, request.clone())?;
3691
3692            for notification_id in &existing_notification_ids {
3693                session.peer.send(
3694                    connection,
3695                    proto::DeleteNotification {
3696                        notification_id: (*notification_id).to_proto(),
3697                    },
3698                )?;
3699            }
3700
3701            Ok(())
3702        },
3703    );
3704    response.send(proto::Ack {})?;
3705    Ok(())
3706}
3707
3708async fn update_channel_message(
3709    request: proto::UpdateChannelMessage,
3710    response: Response<proto::UpdateChannelMessage>,
3711    session: MessageContext,
3712) -> Result<()> {
3713    let channel_id = ChannelId::from_proto(request.channel_id);
3714    let message_id = MessageId::from_proto(request.message_id);
3715    let updated_at = OffsetDateTime::now_utc();
3716    let UpdatedChannelMessage {
3717        message_id,
3718        participant_connection_ids,
3719        notifications,
3720        reply_to_message_id,
3721        timestamp,
3722        deleted_mention_notification_ids,
3723        updated_mention_notifications,
3724    } = session
3725        .db()
3726        .await
3727        .update_channel_message(
3728            channel_id,
3729            message_id,
3730            session.user_id(),
3731            request.body.as_str(),
3732            &request.mentions,
3733            updated_at,
3734        )
3735        .await?;
3736
3737    let nonce = request.nonce.clone().context("nonce can't be blank")?;
3738
3739    let message = proto::ChannelMessage {
3740        sender_id: session.user_id().to_proto(),
3741        id: message_id.to_proto(),
3742        body: request.body.clone(),
3743        mentions: request.mentions.clone(),
3744        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3745        nonce: Some(nonce),
3746        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3747        edited_at: Some(updated_at.unix_timestamp() as u64),
3748    };
3749
3750    response.send(proto::Ack {})?;
3751
3752    let pool = &*session.connection_pool().await;
3753    broadcast(
3754        Some(session.connection_id),
3755        participant_connection_ids,
3756        |connection| {
3757            session.peer.send(
3758                connection,
3759                proto::ChannelMessageUpdate {
3760                    channel_id: channel_id.to_proto(),
3761                    message: Some(message.clone()),
3762                },
3763            )?;
3764
3765            for notification_id in &deleted_mention_notification_ids {
3766                session.peer.send(
3767                    connection,
3768                    proto::DeleteNotification {
3769                        notification_id: (*notification_id).to_proto(),
3770                    },
3771                )?;
3772            }
3773
3774            for notification in &updated_mention_notifications {
3775                session.peer.send(
3776                    connection,
3777                    proto::UpdateNotification {
3778                        notification: Some(notification.clone()),
3779                    },
3780                )?;
3781            }
3782
3783            Ok(())
3784        },
3785    );
3786
3787    send_notifications(pool, &session.peer, notifications);
3788
3789    Ok(())
3790}
3791
3792/// Mark a channel message as read
3793async fn acknowledge_channel_message(
3794    request: proto::AckChannelMessage,
3795    session: MessageContext,
3796) -> Result<()> {
3797    let channel_id = ChannelId::from_proto(request.channel_id);
3798    let message_id = MessageId::from_proto(request.message_id);
3799    let notifications = session
3800        .db()
3801        .await
3802        .observe_channel_message(channel_id, session.user_id(), message_id)
3803        .await?;
3804    send_notifications(
3805        &*session.connection_pool().await,
3806        &session.peer,
3807        notifications,
3808    );
3809    Ok(())
3810}
3811
3812/// Mark a buffer version as synced
3813async fn acknowledge_buffer_version(
3814    request: proto::AckBufferOperation,
3815    session: MessageContext,
3816) -> Result<()> {
3817    let buffer_id = BufferId::from_proto(request.buffer_id);
3818    session
3819        .db()
3820        .await
3821        .observe_buffer_version(
3822            buffer_id,
3823            session.user_id(),
3824            request.epoch as i32,
3825            &request.version,
3826        )
3827        .await?;
3828    Ok(())
3829}
3830
3831/// Get a Supermaven API key for the user
3832async fn get_supermaven_api_key(
3833    _request: proto::GetSupermavenApiKey,
3834    response: Response<proto::GetSupermavenApiKey>,
3835    session: MessageContext,
3836) -> Result<()> {
3837    let user_id: String = session.user_id().to_string();
3838    if !session.is_staff() {
3839        return Err(anyhow!("supermaven not enabled for this account"))?;
3840    }
3841
3842    let email = session.email().context("user must have an email")?;
3843
3844    let supermaven_admin_api = session
3845        .supermaven_client
3846        .as_ref()
3847        .context("supermaven not configured")?;
3848
3849    let result = supermaven_admin_api
3850        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3851        .await?;
3852
3853    response.send(proto::GetSupermavenApiKeyResponse {
3854        api_key: result.api_key,
3855    })?;
3856
3857    Ok(())
3858}
3859
3860/// Start receiving chat updates for a channel
3861async fn join_channel_chat(
3862    request: proto::JoinChannelChat,
3863    response: Response<proto::JoinChannelChat>,
3864    session: MessageContext,
3865) -> Result<()> {
3866    let channel_id = ChannelId::from_proto(request.channel_id);
3867
3868    let db = session.db().await;
3869    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3870        .await?;
3871    let messages = db
3872        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3873        .await?;
3874    response.send(proto::JoinChannelChatResponse {
3875        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3876        messages,
3877    })?;
3878    Ok(())
3879}
3880
3881/// Stop receiving chat updates for a channel
3882async fn leave_channel_chat(
3883    request: proto::LeaveChannelChat,
3884    session: MessageContext,
3885) -> Result<()> {
3886    let channel_id = ChannelId::from_proto(request.channel_id);
3887    session
3888        .db()
3889        .await
3890        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3891        .await?;
3892    Ok(())
3893}
3894
3895/// Retrieve the chat history for a channel
3896async fn get_channel_messages(
3897    request: proto::GetChannelMessages,
3898    response: Response<proto::GetChannelMessages>,
3899    session: MessageContext,
3900) -> Result<()> {
3901    let channel_id = ChannelId::from_proto(request.channel_id);
3902    let messages = session
3903        .db()
3904        .await
3905        .get_channel_messages(
3906            channel_id,
3907            session.user_id(),
3908            MESSAGE_COUNT_PER_PAGE,
3909            Some(MessageId::from_proto(request.before_message_id)),
3910        )
3911        .await?;
3912    response.send(proto::GetChannelMessagesResponse {
3913        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3914        messages,
3915    })?;
3916    Ok(())
3917}
3918
3919/// Retrieve specific chat messages
3920async fn get_channel_messages_by_id(
3921    request: proto::GetChannelMessagesById,
3922    response: Response<proto::GetChannelMessagesById>,
3923    session: MessageContext,
3924) -> Result<()> {
3925    let message_ids = request
3926        .message_ids
3927        .iter()
3928        .map(|id| MessageId::from_proto(*id))
3929        .collect::<Vec<_>>();
3930    let messages = session
3931        .db()
3932        .await
3933        .get_channel_messages_by_id(session.user_id(), &message_ids)
3934        .await?;
3935    response.send(proto::GetChannelMessagesResponse {
3936        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3937        messages,
3938    })?;
3939    Ok(())
3940}
3941
3942/// Retrieve the current users notifications
3943async fn get_notifications(
3944    request: proto::GetNotifications,
3945    response: Response<proto::GetNotifications>,
3946    session: MessageContext,
3947) -> Result<()> {
3948    let notifications = session
3949        .db()
3950        .await
3951        .get_notifications(
3952            session.user_id(),
3953            NOTIFICATION_COUNT_PER_PAGE,
3954            request.before_id.map(db::NotificationId::from_proto),
3955        )
3956        .await?;
3957    response.send(proto::GetNotificationsResponse {
3958        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3959        notifications,
3960    })?;
3961    Ok(())
3962}
3963
3964/// Mark notifications as read
3965async fn mark_notification_as_read(
3966    request: proto::MarkNotificationRead,
3967    response: Response<proto::MarkNotificationRead>,
3968    session: MessageContext,
3969) -> Result<()> {
3970    let database = &session.db().await;
3971    let notifications = database
3972        .mark_notification_as_read_by_id(
3973            session.user_id(),
3974            NotificationId::from_proto(request.notification_id),
3975        )
3976        .await?;
3977    send_notifications(
3978        &*session.connection_pool().await,
3979        &session.peer,
3980        notifications,
3981    );
3982    response.send(proto::Ack {})?;
3983    Ok(())
3984}
3985
3986fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3987    let message = match message {
3988        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3989        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3990        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3991        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3992        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3993            code: frame.code.into(),
3994            reason: frame.reason.as_str().to_owned().into(),
3995        })),
3996        // We should never receive a frame while reading the message, according
3997        // to the `tungstenite` maintainers:
3998        //
3999        // > It cannot occur when you read messages from the WebSocket, but it
4000        // > can be used when you want to send the raw frames (e.g. you want to
4001        // > send the frames to the WebSocket without composing the full message first).
4002        // >
4003        // > — https://github.com/snapview/tungstenite-rs/issues/268
4004        TungsteniteMessage::Frame(_) => {
4005            bail!("received an unexpected frame while reading the message")
4006        }
4007    };
4008
4009    Ok(message)
4010}
4011
4012fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4013    match message {
4014        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4015        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4016        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4017        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4018        AxumMessage::Close(frame) => {
4019            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4020                code: frame.code.into(),
4021                reason: frame.reason.as_ref().into(),
4022            }))
4023        }
4024    }
4025}
4026
4027fn notify_membership_updated(
4028    connection_pool: &mut ConnectionPool,
4029    result: MembershipUpdated,
4030    user_id: UserId,
4031    peer: &Peer,
4032) {
4033    for membership in &result.new_channels.channel_memberships {
4034        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4035    }
4036    for channel_id in &result.removed_channels {
4037        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4038    }
4039
4040    let user_channels_update = proto::UpdateUserChannels {
4041        channel_memberships: result
4042            .new_channels
4043            .channel_memberships
4044            .iter()
4045            .map(|cm| proto::ChannelMembership {
4046                channel_id: cm.channel_id.to_proto(),
4047                role: cm.role.into(),
4048            })
4049            .collect(),
4050        ..Default::default()
4051    };
4052
4053    let mut update = build_channels_update(result.new_channels);
4054    update.delete_channels = result
4055        .removed_channels
4056        .into_iter()
4057        .map(|id| id.to_proto())
4058        .collect();
4059    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4060
4061    for connection_id in connection_pool.user_connection_ids(user_id) {
4062        peer.send(connection_id, user_channels_update.clone())
4063            .trace_err();
4064        peer.send(connection_id, update.clone()).trace_err();
4065    }
4066}
4067
4068fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4069    proto::UpdateUserChannels {
4070        channel_memberships: channels
4071            .channel_memberships
4072            .iter()
4073            .map(|m| proto::ChannelMembership {
4074                channel_id: m.channel_id.to_proto(),
4075                role: m.role.into(),
4076            })
4077            .collect(),
4078        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4079        observed_channel_message_id: channels.observed_channel_messages.clone(),
4080    }
4081}
4082
4083fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4084    let mut update = proto::UpdateChannels::default();
4085
4086    for channel in channels.channels {
4087        update.channels.push(channel.to_proto());
4088    }
4089
4090    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4091    update.latest_channel_message_ids = channels.latest_channel_messages;
4092
4093    for (channel_id, participants) in channels.channel_participants {
4094        update
4095            .channel_participants
4096            .push(proto::ChannelParticipants {
4097                channel_id: channel_id.to_proto(),
4098                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4099            });
4100    }
4101
4102    for channel in channels.invited_channels {
4103        update.channel_invitations.push(channel.to_proto());
4104    }
4105
4106    update
4107}
4108
4109fn build_initial_contacts_update(
4110    contacts: Vec<db::Contact>,
4111    pool: &ConnectionPool,
4112) -> proto::UpdateContacts {
4113    let mut update = proto::UpdateContacts::default();
4114
4115    for contact in contacts {
4116        match contact {
4117            db::Contact::Accepted { user_id, busy } => {
4118                update.contacts.push(contact_for_user(user_id, busy, pool));
4119            }
4120            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4121            db::Contact::Incoming { user_id } => {
4122                update
4123                    .incoming_requests
4124                    .push(proto::IncomingContactRequest {
4125                        requester_id: user_id.to_proto(),
4126                    })
4127            }
4128        }
4129    }
4130
4131    update
4132}
4133
4134fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4135    proto::Contact {
4136        user_id: user_id.to_proto(),
4137        online: pool.is_user_online(user_id),
4138        busy,
4139    }
4140}
4141
4142fn room_updated(room: &proto::Room, peer: &Peer) {
4143    broadcast(
4144        None,
4145        room.participants
4146            .iter()
4147            .filter_map(|participant| Some(participant.peer_id?.into())),
4148        |peer_id| {
4149            peer.send(
4150                peer_id,
4151                proto::RoomUpdated {
4152                    room: Some(room.clone()),
4153                },
4154            )
4155        },
4156    );
4157}
4158
4159fn channel_updated(
4160    channel: &db::channel::Model,
4161    room: &proto::Room,
4162    peer: &Peer,
4163    pool: &ConnectionPool,
4164) {
4165    let participants = room
4166        .participants
4167        .iter()
4168        .map(|p| p.user_id)
4169        .collect::<Vec<_>>();
4170
4171    broadcast(
4172        None,
4173        pool.channel_connection_ids(channel.root_id())
4174            .filter_map(|(channel_id, role)| {
4175                role.can_see_channel(channel.visibility)
4176                    .then_some(channel_id)
4177            }),
4178        |peer_id| {
4179            peer.send(
4180                peer_id,
4181                proto::UpdateChannels {
4182                    channel_participants: vec![proto::ChannelParticipants {
4183                        channel_id: channel.id.to_proto(),
4184                        participant_user_ids: participants.clone(),
4185                    }],
4186                    ..Default::default()
4187                },
4188            )
4189        },
4190    );
4191}
4192
4193async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4194    let db = session.db().await;
4195
4196    let contacts = db.get_contacts(user_id).await?;
4197    let busy = db.is_user_busy(user_id).await?;
4198
4199    let pool = session.connection_pool().await;
4200    let updated_contact = contact_for_user(user_id, busy, &pool);
4201    for contact in contacts {
4202        if let db::Contact::Accepted {
4203            user_id: contact_user_id,
4204            ..
4205        } = contact
4206        {
4207            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4208                session
4209                    .peer
4210                    .send(
4211                        contact_conn_id,
4212                        proto::UpdateContacts {
4213                            contacts: vec![updated_contact.clone()],
4214                            remove_contacts: Default::default(),
4215                            incoming_requests: Default::default(),
4216                            remove_incoming_requests: Default::default(),
4217                            outgoing_requests: Default::default(),
4218                            remove_outgoing_requests: Default::default(),
4219                        },
4220                    )
4221                    .trace_err();
4222            }
4223        }
4224    }
4225    Ok(())
4226}
4227
4228async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4229    let mut contacts_to_update = HashSet::default();
4230
4231    let room_id;
4232    let canceled_calls_to_user_ids;
4233    let livekit_room;
4234    let delete_livekit_room;
4235    let room;
4236    let channel;
4237
4238    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4239        contacts_to_update.insert(session.user_id());
4240
4241        for project in left_room.left_projects.values() {
4242            project_left(project, session);
4243        }
4244
4245        room_id = RoomId::from_proto(left_room.room.id);
4246        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4247        livekit_room = mem::take(&mut left_room.room.livekit_room);
4248        delete_livekit_room = left_room.deleted;
4249        room = mem::take(&mut left_room.room);
4250        channel = mem::take(&mut left_room.channel);
4251
4252        room_updated(&room, &session.peer);
4253    } else {
4254        return Ok(());
4255    }
4256
4257    if let Some(channel) = channel {
4258        channel_updated(
4259            &channel,
4260            &room,
4261            &session.peer,
4262            &*session.connection_pool().await,
4263        );
4264    }
4265
4266    {
4267        let pool = session.connection_pool().await;
4268        for canceled_user_id in canceled_calls_to_user_ids {
4269            for connection_id in pool.user_connection_ids(canceled_user_id) {
4270                session
4271                    .peer
4272                    .send(
4273                        connection_id,
4274                        proto::CallCanceled {
4275                            room_id: room_id.to_proto(),
4276                        },
4277                    )
4278                    .trace_err();
4279            }
4280            contacts_to_update.insert(canceled_user_id);
4281        }
4282    }
4283
4284    for contact_user_id in contacts_to_update {
4285        update_user_contacts(contact_user_id, session).await?;
4286    }
4287
4288    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4289        live_kit
4290            .remove_participant(livekit_room.clone(), session.user_id().to_string())
4291            .await
4292            .trace_err();
4293
4294        if delete_livekit_room {
4295            live_kit.delete_room(livekit_room).await.trace_err();
4296        }
4297    }
4298
4299    Ok(())
4300}
4301
4302async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4303    let left_channel_buffers = session
4304        .db()
4305        .await
4306        .leave_channel_buffers(session.connection_id)
4307        .await?;
4308
4309    for left_buffer in left_channel_buffers {
4310        channel_buffer_updated(
4311            session.connection_id,
4312            left_buffer.connections,
4313            &proto::UpdateChannelBufferCollaborators {
4314                channel_id: left_buffer.channel_id.to_proto(),
4315                collaborators: left_buffer.collaborators,
4316            },
4317            &session.peer,
4318        );
4319    }
4320
4321    Ok(())
4322}
4323
4324fn project_left(project: &db::LeftProject, session: &Session) {
4325    for connection_id in &project.connection_ids {
4326        if project.should_unshare {
4327            session
4328                .peer
4329                .send(
4330                    *connection_id,
4331                    proto::UnshareProject {
4332                        project_id: project.id.to_proto(),
4333                    },
4334                )
4335                .trace_err();
4336        } else {
4337            session
4338                .peer
4339                .send(
4340                    *connection_id,
4341                    proto::RemoveProjectCollaborator {
4342                        project_id: project.id.to_proto(),
4343                        peer_id: Some(session.connection_id.into()),
4344                    },
4345                )
4346                .trace_err();
4347        }
4348    }
4349}
4350
/// Turns a fallible value into an `Option`, reporting the failure on the side
/// instead of propagating it.
///
/// Intended for best-effort operations whose failure should be recorded but
/// must not abort the caller.
pub trait ResultExt {
    /// The success value yielded when there is no failure.
    type Ok;

    /// Returns `Some(value)` on success. On failure it returns `None` after
    /// handling the error in an implementation-defined way (the `Result`
    /// implementation in this module logs it).
    fn trace_err(self) -> Option<Self::Ok>;
}
4356
4357impl<T, E> ResultExt for Result<T, E>
4358where
4359    E: std::fmt::Debug,
4360{
4361    type Ok = T;
4362
4363    #[track_caller]
4364    fn trace_err(self) -> Option<T> {
4365        match self {
4366            Ok(value) => Some(value),
4367            Err(error) => {
4368                tracing::error!("{:?}", error);
4369                None
4370            }
4371        }
4372    }
4373}