rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::{
   5    AppState, Error, Result, auth,
   6    db::{
   7        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   8        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
   9        NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
  10        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  11    },
  12    executor::Executor,
  13};
  14use anyhow::{Context as _, anyhow, bail};
  15use async_tungstenite::tungstenite::{
  16    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  17};
  18use axum::headers::UserAgent;
  19use axum::{
  20    Extension, Router, TypedHeader,
  21    body::Body,
  22    extract::{
  23        ConnectInfo, WebSocketUpgrade,
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31};
  32use chrono::Utc;
  33use collections::{HashMap, HashSet};
  34pub use connection_pool::{ConnectionPool, ZedVersion};
  35use core::fmt::{self, Debug, Formatter};
  36use futures::TryFutureExt as _;
  37use reqwest_client::ReqwestClient;
  38use rpc::proto::{MultiLspQuery, split_repository_update};
  39use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  40use tracing::Span;
  41
  42use futures::{
  43    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  44    stream::FuturesUnordered,
  45};
  46use prometheus::{IntGauge, register_int_gauge};
  47use rpc::{
  48    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  49    proto::{
  50        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  51        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  52    },
  53};
  54use semantic_version::SemanticVersion;
  55use serde::{Serialize, Serializer};
  56use std::{
  57    any::TypeId,
  58    future::Future,
  59    marker::PhantomData,
  60    mem,
  61    net::SocketAddr,
  62    ops::{Deref, DerefMut},
  63    rc::Rc,
  64    sync::{
  65        Arc, OnceLock,
  66        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  67    },
  68    time::{Duration, Instant},
  69};
  70use time::OffsetDateTime;
  71use tokio::sync::{Semaphore, watch};
  72use tower::ServiceBuilder;
  73use tracing::{
  74    Instrument,
  75    field::{self},
  76    info_span, instrument,
  77};
  78
// NOTE(review): not used in this part of the file; the name suggests the window a
// disconnected client has to reconnect before its state is dropped — confirm.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay `Server::start` waits before reclaiming resources left by previous server
/// instances.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; waiting a little
/// longer means they are gone before we clean up their old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the next three limits are used by handlers outside this part of the
// file; the descriptions are inferred from their names — confirm against usage.
/// Page size, presumably for channel-message history requests.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum message length, presumably for channel messages (bytes vs. chars not visible here).
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size, presumably for notification requests.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Maximum number of `ConnectionGuard`s that may be held at once.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Number of `ConnectionGuard`s currently held: incremented by
/// `ConnectionGuard::try_acquire`, decremented when a guard is dropped.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Span field names used to record per-message timing (`Server::add_handler`) and
// the time spent waiting on a forwarded request (`MessageContext::forward_request`).
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased message handler stored in `Server::handlers`, keyed by the message's
/// `TypeId`. It receives the envelope, the sender's session and the message's span,
/// and returns a boxed `'static` future that performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
  98
  99pub struct ConnectionGuard;
 100
 101impl ConnectionGuard {
 102    pub fn try_acquire() -> Result<Self, ()> {
 103        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
 104        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 105            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 106            tracing::error!(
 107                "too many concurrent connections: {}",
 108                current_connections + 1
 109            );
 110            return Err(());
 111        }
 112        Ok(ConnectionGuard)
 113    }
 114}
 115
 116impl Drop for ConnectionGuard {
 117    fn drop(&mut self) {
 118        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 119    }
 120}
 121
 122struct Response<R> {
 123    peer: Arc<Peer>,
 124    receipt: Receipt<R>,
 125    responded: Arc<AtomicBool>,
 126}
 127
 128impl<R: RequestMessage> Response<R> {
 129    fn send(self, payload: R::Response) -> Result<()> {
 130        self.responded.store(true, SeqCst);
 131        self.peer.respond(self.receipt, payload)?;
 132        Ok(())
 133    }
 134}
 135
 136#[derive(Clone, Debug)]
 137pub enum Principal {
 138    User(User),
 139    Impersonated { user: User, admin: User },
 140}
 141
 142impl Principal {
 143    fn update_span(&self, span: &tracing::Span) {
 144        match &self {
 145            Principal::User(user) => {
 146                span.record("user_id", user.id.0);
 147                span.record("login", &user.github_login);
 148            }
 149            Principal::Impersonated { user, admin } => {
 150                span.record("user_id", user.id.0);
 151                span.record("login", &user.github_login);
 152                span.record("impersonator", &admin.github_login);
 153            }
 154        }
 155    }
 156}
 157
 158#[derive(Clone)]
 159struct MessageContext {
 160    session: Session,
 161    span: tracing::Span,
 162}
 163
 164impl Deref for MessageContext {
 165    type Target = Session;
 166
 167    fn deref(&self) -> &Self::Target {
 168        &self.session
 169    }
 170}
 171
 172impl MessageContext {
 173    pub fn forward_request<T: RequestMessage>(
 174        &self,
 175        receiver_id: ConnectionId,
 176        request: T,
 177    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 178        let request_start_time = Instant::now();
 179        let span = self.span.clone();
 180        tracing::info!("start forwarding request");
 181        self.peer
 182            .forward_request(self.connection_id, receiver_id, request)
 183            .inspect(move |_| {
 184                span.record(
 185                    HOST_WAITING_MS,
 186                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 187                );
 188            })
 189            .inspect_err(|_| tracing::error!("error forwarding request"))
 190            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 191    }
 192}
 193
 194#[derive(Clone)]
 195struct Session {
 196    principal: Principal,
 197    connection_id: ConnectionId,
 198    db: Arc<tokio::sync::Mutex<DbHandle>>,
 199    peer: Arc<Peer>,
 200    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 201    app_state: Arc<AppState>,
 202    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 203    /// The GeoIP country code for the user.
 204    #[allow(unused)]
 205    geoip_country_code: Option<String>,
 206    #[allow(unused)]
 207    system_id: Option<String>,
 208    _executor: Executor,
 209}
 210
 211impl Session {
 212    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 213        #[cfg(test)]
 214        tokio::task::yield_now().await;
 215        let guard = self.db.lock().await;
 216        #[cfg(test)]
 217        tokio::task::yield_now().await;
 218        guard
 219    }
 220
 221    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 222        #[cfg(test)]
 223        tokio::task::yield_now().await;
 224        let guard = self.connection_pool.lock();
 225        ConnectionPoolGuard {
 226            guard,
 227            _not_send: PhantomData,
 228        }
 229    }
 230
 231    fn is_staff(&self) -> bool {
 232        match &self.principal {
 233            Principal::User(user) => user.admin,
 234            Principal::Impersonated { .. } => true,
 235        }
 236    }
 237
 238    fn user_id(&self) -> UserId {
 239        match &self.principal {
 240            Principal::User(user) => user.id,
 241            Principal::Impersonated { user, .. } => user.id,
 242        }
 243    }
 244
 245    pub fn email(&self) -> Option<String> {
 246        match &self.principal {
 247            Principal::User(user) => user.email_address.clone(),
 248            Principal::Impersonated { user, .. } => user.email_address.clone(),
 249        }
 250    }
 251}
 252
 253impl Debug for Session {
 254    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 255        let mut result = f.debug_struct("Session");
 256        match &self.principal {
 257            Principal::User(user) => {
 258                result.field("user", &user.github_login);
 259            }
 260            Principal::Impersonated { user, admin } => {
 261                result.field("user", &user.github_login);
 262                result.field("impersonator", &admin.github_login);
 263            }
 264        }
 265        result.field("connection_id", &self.connection_id).finish()
 266    }
 267}
 268
 269struct DbHandle(Arc<Database>);
 270
 271impl Deref for DbHandle {
 272    type Target = Database;
 273
 274    fn deref(&self) -> &Self::Target {
 275        self.0.as_ref()
 276    }
 277}
 278
/// The RPC server. It owns the [`Peer`], the shared connection pool, the app state
/// and the per-message-type handler table.
pub struct Server {
    // Behind a mutex so the test-only `reset` can swap in a new id; `start` reads it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // Type-erased handlers keyed by the `TypeId` of the message type they accept;
    // filled in by `add_handler` during `Server::new`.
    handlers: HashMap<TypeId, MessageHandler>,
    // Tear-down flag: `teardown()` publishes `true`, the test-only `reset()` publishes `false`.
    teardown: watch::Sender<bool>,
}
 287
/// Lock guard over the shared [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes this guard neither `Send` nor `Sync`.
/// Presumably this is so a `Send` future cannot hold the blocking `parking_lot` lock
/// across an `.await` — the compiler rejects such futures.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 292
/// Serializable view of the server's peer and connection pool.
///
/// Holds the connection-pool lock for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialize the pool the guard points at, not the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 299
 300pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 301where
 302    S: Serializer,
 303    T: Deref<Target = U>,
 304    U: Serialize,
 305{
 306    Serialize::serialize(value.deref(), serializer)
 307}
 308
 309impl Server {
 310    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 311        let mut server = Self {
 312            id: parking_lot::Mutex::new(id),
 313            peer: Peer::new(id.0 as u32),
 314            app_state: app_state.clone(),
 315            connection_pool: Default::default(),
 316            handlers: Default::default(),
 317            teardown: watch::channel(false).0,
 318        };
 319
 320        server
 321            .add_request_handler(ping)
 322            .add_request_handler(create_room)
 323            .add_request_handler(join_room)
 324            .add_request_handler(rejoin_room)
 325            .add_request_handler(leave_room)
 326            .add_request_handler(set_room_participant_role)
 327            .add_request_handler(call)
 328            .add_request_handler(cancel_call)
 329            .add_message_handler(decline_call)
 330            .add_request_handler(update_participant_location)
 331            .add_request_handler(share_project)
 332            .add_message_handler(unshare_project)
 333            .add_request_handler(join_project)
 334            .add_message_handler(leave_project)
 335            .add_request_handler(update_project)
 336            .add_request_handler(update_worktree)
 337            .add_request_handler(update_repository)
 338            .add_request_handler(remove_repository)
 339            .add_message_handler(start_language_server)
 340            .add_message_handler(update_language_server)
 341            .add_message_handler(update_diagnostic_summary)
 342            .add_message_handler(update_worktree_settings)
 343            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 344            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 345            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 346            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 347            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 348            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 349            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 350            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 351            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 352            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 353            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 354            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 355            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 356            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 357            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 358            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 359            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 360            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 361            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 362            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 363            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 364            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 365            .add_request_handler(
 366                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 367            )
 368            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 369            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 370            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 371            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 372            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
 373            .add_request_handler(
 374                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 375            )
 376            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 377            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 378            .add_request_handler(
 379                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 380            )
 381            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 382            .add_request_handler(
 383                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 384            )
 385            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 386            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 387            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 388            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 389            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 390            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 391            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 392            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 393            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 394            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 395            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 396            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 397            .add_request_handler(
 398                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 399            )
 400            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 401            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 402            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 403            .add_request_handler(multi_lsp_query)
 404            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 405            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 406            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 407            .add_message_handler(create_buffer_for_peer)
 408            .add_request_handler(update_buffer)
 409            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 410            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 411            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 412            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 413            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 414            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 415            .add_message_handler(
 416                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 417            )
 418            .add_request_handler(get_users)
 419            .add_request_handler(fuzzy_search_users)
 420            .add_request_handler(request_contact)
 421            .add_request_handler(remove_contact)
 422            .add_request_handler(respond_to_contact_request)
 423            .add_message_handler(subscribe_to_channels)
 424            .add_request_handler(create_channel)
 425            .add_request_handler(delete_channel)
 426            .add_request_handler(invite_channel_member)
 427            .add_request_handler(remove_channel_member)
 428            .add_request_handler(set_channel_member_role)
 429            .add_request_handler(set_channel_visibility)
 430            .add_request_handler(rename_channel)
 431            .add_request_handler(join_channel_buffer)
 432            .add_request_handler(leave_channel_buffer)
 433            .add_message_handler(update_channel_buffer)
 434            .add_request_handler(rejoin_channel_buffers)
 435            .add_request_handler(get_channel_members)
 436            .add_request_handler(respond_to_channel_invite)
 437            .add_request_handler(join_channel)
 438            .add_request_handler(join_channel_chat)
 439            .add_message_handler(leave_channel_chat)
 440            .add_request_handler(send_channel_message)
 441            .add_request_handler(remove_channel_message)
 442            .add_request_handler(update_channel_message)
 443            .add_request_handler(get_channel_messages)
 444            .add_request_handler(get_channel_messages_by_id)
 445            .add_request_handler(get_notifications)
 446            .add_request_handler(mark_notification_as_read)
 447            .add_request_handler(move_channel)
 448            .add_request_handler(reorder_channel)
 449            .add_request_handler(follow)
 450            .add_message_handler(unfollow)
 451            .add_message_handler(update_followers)
 452            .add_request_handler(accept_terms_of_service)
 453            .add_message_handler(acknowledge_channel_message)
 454            .add_message_handler(acknowledge_buffer_version)
 455            .add_request_handler(get_supermaven_api_key)
 456            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 457            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 458            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 459            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 460            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 461            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 462            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 463            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 464            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 465            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 466            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 467            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 468            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 469            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 470            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 471            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 472            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 473            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 474            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 475            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 476            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 477            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 478            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 479            .add_message_handler(update_context);
 480
 481        Arc::new(server)
 482    }
 483
    /// Schedules cleanup of resources left behind by previous server instances.
    ///
    /// Returns right away. The work runs on a detached task, instrumented with the
    /// `start server` span, which sleeps for `CLEANUP_TIMEOUT` and then uses the server
    /// id captured at call time to:
    /// 1. delete stale channel-chat participants;
    /// 2. for each stale channel buffer, clear its stale collaborators and send the
    ///    refreshed collaborator list to the remaining connections;
    /// 3. for each stale room, clear its stale participants and broadcast the updated
    ///    room (and its channel, if any). It also notifies users whose calls were
    ///    canceled, pushes refreshed contact entries to the affected users' accepted
    ///    contacts, and, when a LiveKit client is configured, deletes the LiveKit room
    ///    if no participants remain;
    /// 4. delete stale channel-chat participants again, clear old worktree entries, and
    ///    delete stale servers.
    ///
    /// Every fallible step is passed through `trace_err` (presumably logging the error),
    /// so a failure only skips that step. Nothing is propagated to the caller, which
    /// always receives `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        // Capture everything the detached task needs so it does not borrow `self`.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                // Give pods from the previous deployment time to shut down first.
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Remove stale collaborators from each channel buffer and send the
                    // refreshed collaborator list to the connections that remain.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled in from the refreshed room below; they keep these
                        // defaults (nothing to notify, no LiveKit deletion) if the
                        // refresh fails.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            // Move these out of the room instead of cloning them.
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        {
                            // Scoped so the blocking pool lock is released before the
                            // `.await`s in the contact loop below.
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // For each affected user, push their refreshed contact entry
                        // (with current busy state) to every connection of each
                        // accepted contact. The DB lookups run before the pool is locked.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only tear down the LiveKit room once nobody is left in it.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // NOTE(review): repeats the call made before the stale-resource pass;
                // confirm whether running it twice is intentional.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 654
 655    pub fn teardown(&self) {
 656        self.peer.teardown();
 657        self.connection_pool.lock().reset();
 658        let _ = self.teardown.send(true);
 659    }
 660
 661    #[cfg(test)]
 662    pub fn reset(&self, id: ServerId) {
 663        self.teardown();
 664        *self.id.lock() = id;
 665        self.peer.reset(id.0 as u32);
 666        let _ = self.teardown.send(false);
 667    }
 668
 669    #[cfg(test)]
 670    pub fn id(&self) -> ServerId {
 671        *self.id.lock()
 672    }
 673
    /// Registers the type-erased handler used for every incoming message of type `M`.
    ///
    /// The stored closure downcasts the envelope back to `TypedEnvelope<M>`, calls
    /// `handler` immediately to build its future, and wraps that future so that, once it
    /// completes, the total, processing and queue durations (in milliseconds) are recorded
    /// on `span` and the outcome is logged. Handler errors are logged, not propagated: the
    /// wrapped future resolves to `()`.
    ///
    /// # Panics
    ///
    /// Panics if a handler for `M` was already registered. The stored closure also panics
    /// (via `unwrap`) if it is ever called with an envelope whose payload is not `M`.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // Handlers are keyed by `TypeId::of::<M>()`, so the payload is expected to be `M`.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                // Processing time is measured from here, just before the handler's future is built.
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // Millisecond values with microsecond resolution. Queue time is the part of
                    // the total that elapsed before the handler was invoked.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
 717
 718    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 719    where
 720        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 721        Fut: 'static + Send + Future<Output = Result<()>>,
 722        M: EnvelopedMessage,
 723    {
 724        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 725        self
 726    }
 727
 728    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 729    where
 730        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 731        Fut: Send + Future<Output = Result<()>>,
 732        M: RequestMessage,
 733    {
 734        let handler = Arc::new(handler);
 735        self.add_handler(move |envelope, session| {
 736            let receipt = envelope.receipt();
 737            let handler = handler.clone();
 738            async move {
 739                let peer = session.peer.clone();
 740                let responded = Arc::new(AtomicBool::default());
 741                let response = Response {
 742                    peer: peer.clone(),
 743                    responded: responded.clone(),
 744                    receipt,
 745                };
 746                match (handler)(envelope.payload, response, session).await {
 747                    Ok(()) => {
 748                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 749                            Ok(())
 750                        } else {
 751                            Err(anyhow!("handler did not send a response"))?
 752                        }
 753                    }
 754                    Err(error) => {
 755                        let proto_err = match &error {
 756                            Error::Internal(err) => err.to_proto(),
 757                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 758                        };
 759                        peer.respond_with_error(receipt, proto_err)?;
 760                        Err(error)
 761                    }
 762                }
 763            }
 764        })
 765    }
 766
    /// Builds the future that drives one client connection from handshake to sign-out.
    ///
    /// When polled, the future:
    /// 1. Returns immediately if the teardown flag is already set.
    /// 2. Registers `connection` with the peer and records the connection id on the span.
    /// 3. Builds the per-connection [`Session`] and runs `send_initial_client_update`,
    ///    releasing `connection_guard` once that succeeds.
    /// 4. Dispatches incoming messages to registered handlers. At most
    ///    `MAX_CONCURRENT_HANDLERS` handlers run at once.
    ///
    /// Dispatch stops when the I/O future completes or the incoming stream ends; both
    /// paths then run `connection_lost`. A change of the teardown signal returns
    /// immediately without calling `connection_lost`.
    ///
    /// The future owns a clone of the `Arc<Server>` (`this`), so it does not borrow
    /// `self` (hence `+ use<>`).
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Connection-level span; the empty fields are filled in below or by `principal.update_span`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribe before the future starts so a teardown that happens in between is still observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            // NOTE(review): this return (and the one after `send_initial_client_update`
            // below) happens after `peer.add_connection` without calling `connection_lost`
            // or `peer.disconnect`. If the initial update fails after `pool.add_connection`,
            // the pool entry is not removed here either. Confirm cleanup happens elsewhere.
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only built when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The handshake succeeded; release the slot reserved before the WebSocket upgrade.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Each in-flight handler (foreground or background) holds one permit until it finishes.
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // A permit is acquired before reading, so reading pauses while all permits are in use.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // Arms are checked in order: teardown, I/O completion, finished foreground
                // handlers, then the next incoming message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                multi_lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            // Enter the span only while the handler's future is built synchronously.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // Release the permit only after the handler has finished.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels foreground handlers still in flight; background
            // handlers were spawned detached and are not affected.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
 958
 959    async fn send_initial_client_update(
 960        &self,
 961        connection_id: ConnectionId,
 962        zed_version: ZedVersion,
 963        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 964        session: &Session,
 965    ) -> Result<()> {
 966        self.peer.send(
 967            connection_id,
 968            proto::Hello {
 969                peer_id: Some(connection_id.into()),
 970            },
 971        )?;
 972        tracing::info!("sent hello message");
 973        if let Some(send_connection_id) = send_connection_id.take() {
 974            let _ = send_connection_id.send(connection_id);
 975        }
 976
 977        match &session.principal {
 978            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 979                if !user.connected_once {
 980                    self.peer.send(connection_id, proto::ShowContacts {})?;
 981                    self.app_state
 982                        .db
 983                        .set_user_connected_once(user.id, true)
 984                        .await?;
 985                }
 986
 987                let contacts = self.app_state.db.get_contacts(user.id).await?;
 988
 989                {
 990                    let mut pool = self.connection_pool.lock();
 991                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 992                    self.peer.send(
 993                        connection_id,
 994                        build_initial_contacts_update(contacts, &pool),
 995                    )?;
 996                }
 997
 998                if should_auto_subscribe_to_channels(zed_version) {
 999                    subscribe_user_to_channels(user.id, session).await?;
1000                }
1001
1002                if let Some(incoming_call) =
1003                    self.app_state.db.incoming_call_for_user(user.id).await?
1004                {
1005                    self.peer.send(connection_id, incoming_call)?;
1006                }
1007
1008                update_user_contacts(user.id, session).await?;
1009            }
1010        }
1011
1012        Ok(())
1013    }
1014
1015    pub async fn invite_code_redeemed(
1016        self: &Arc<Self>,
1017        inviter_id: UserId,
1018        invitee_id: UserId,
1019    ) -> Result<()> {
1020        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1021            if let Some(code) = &user.invite_code {
1022                let pool = self.connection_pool.lock();
1023                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1024                for connection_id in pool.user_connection_ids(inviter_id) {
1025                    self.peer.send(
1026                        connection_id,
1027                        proto::UpdateContacts {
1028                            contacts: vec![invitee_contact.clone()],
1029                            ..Default::default()
1030                        },
1031                    )?;
1032                    self.peer.send(
1033                        connection_id,
1034                        proto::UpdateInviteInfo {
1035                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1036                            count: user.invite_count as u32,
1037                        },
1038                    )?;
1039                }
1040            }
1041        }
1042        Ok(())
1043    }
1044
1045    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1046        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1047            if let Some(invite_code) = &user.invite_code {
1048                let pool = self.connection_pool.lock();
1049                for connection_id in pool.user_connection_ids(user_id) {
1050                    self.peer.send(
1051                        connection_id,
1052                        proto::UpdateInviteInfo {
1053                            url: format!(
1054                                "{}{}",
1055                                self.app_state.config.invite_link_prefix, invite_code
1056                            ),
1057                            count: user.invite_count as u32,
1058                        },
1059                    )?;
1060                }
1061            }
1062        }
1063        Ok(())
1064    }
1065
1066    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1067        ServerSnapshot {
1068            connection_pool: ConnectionPoolGuard {
1069                guard: self.connection_pool.lock(),
1070                _not_send: PhantomData,
1071            },
1072            peer: &self.peer,
1073        }
1074    }
1075}
1076
1077impl Deref for ConnectionPoolGuard<'_> {
1078    type Target = ConnectionPool;
1079
1080    fn deref(&self) -> &Self::Target {
1081        &self.guard
1082    }
1083}
1084
1085impl DerefMut for ConnectionPoolGuard<'_> {
1086    fn deref_mut(&mut self) -> &mut Self::Target {
1087        &mut self.guard
1088    }
1089}
1090
1091impl Drop for ConnectionPoolGuard<'_> {
1092    fn drop(&mut self) {
1093        #[cfg(test)]
1094        self.check_invariants();
1095    }
1096}
1097
1098fn broadcast<F>(
1099    sender_id: Option<ConnectionId>,
1100    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1101    mut f: F,
1102) where
1103    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1104{
1105    for receiver_id in receiver_ids {
1106        if Some(receiver_id) != sender_id {
1107            if let Err(error) = f(receiver_id) {
1108                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1109            }
1110        }
1111    }
1112}
1113
1114pub struct ProtocolVersion(u32);
1115
1116impl Header for ProtocolVersion {
1117    fn name() -> &'static HeaderName {
1118        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1119        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1120    }
1121
1122    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1123    where
1124        Self: Sized,
1125        I: Iterator<Item = &'i axum::http::HeaderValue>,
1126    {
1127        let version = values
1128            .next()
1129            .ok_or_else(axum::headers::Error::invalid)?
1130            .to_str()
1131            .map_err(|_| axum::headers::Error::invalid())?
1132            .parse()
1133            .map_err(|_| axum::headers::Error::invalid())?;
1134        Ok(Self(version))
1135    }
1136
1137    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1138        values.extend([self.0.to_string().parse().unwrap()]);
1139    }
1140}
1141
1142pub struct AppVersionHeader(SemanticVersion);
1143impl Header for AppVersionHeader {
1144    fn name() -> &'static HeaderName {
1145        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1146        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1147    }
1148
1149    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1150    where
1151        Self: Sized,
1152        I: Iterator<Item = &'i axum::http::HeaderValue>,
1153    {
1154        let version = values
1155            .next()
1156            .ok_or_else(axum::headers::Error::invalid)?
1157            .to_str()
1158            .map_err(|_| axum::headers::Error::invalid())?
1159            .parse()
1160            .map_err(|_| axum::headers::Error::invalid())?;
1161        Ok(Self(version))
1162    }
1163
1164    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1165        values.extend([self.0.to_string().parse().unwrap()]);
1166    }
1167}
1168
1169#[derive(Debug)]
1170pub struct ReleaseChannelHeader(String);
1171
1172impl Header for ReleaseChannelHeader {
1173    fn name() -> &'static HeaderName {
1174        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1175        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1176    }
1177
1178    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1179    where
1180        Self: Sized,
1181        I: Iterator<Item = &'i axum::http::HeaderValue>,
1182    {
1183        Ok(Self(
1184            values
1185                .next()
1186                .ok_or_else(axum::headers::Error::invalid)?
1187                .to_str()
1188                .map_err(|_| axum::headers::Error::invalid())?
1189                .to_owned(),
1190        ))
1191    }
1192
1193    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1194        values.extend([self.0.parse().unwrap()]);
1195    }
1196}
1197
1198pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1199    Router::new()
1200        .route("/rpc", get(handle_websocket_request))
1201        .layer(
1202            ServiceBuilder::new()
1203                .layer(Extension(server.app_state.clone()))
1204                .layer(middleware::from_fn(auth::validate_header)),
1205        )
1206        .route("/metrics", get(handle_metrics))
1207        .layer(Extension(server))
1208}
1209
1210pub async fn handle_websocket_request(
1211    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1212    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1213    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1214    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1215    Extension(server): Extension<Arc<Server>>,
1216    Extension(principal): Extension<Principal>,
1217    user_agent: Option<TypedHeader<UserAgent>>,
1218    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1219    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1220    ws: WebSocketUpgrade,
1221) -> axum::response::Response {
1222    if protocol_version != rpc::PROTOCOL_VERSION {
1223        return (
1224            StatusCode::UPGRADE_REQUIRED,
1225            "client must be upgraded".to_string(),
1226        )
1227            .into_response();
1228    }
1229
1230    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1231        return (
1232            StatusCode::UPGRADE_REQUIRED,
1233            "no version header found".to_string(),
1234        )
1235            .into_response();
1236    };
1237
1238    let release_channel = release_channel_header.map(|header| header.0.0);
1239
1240    if !version.can_collaborate() {
1241        return (
1242            StatusCode::UPGRADE_REQUIRED,
1243            "client must be upgraded".to_string(),
1244        )
1245            .into_response();
1246    }
1247
1248    let socket_address = socket_address.to_string();
1249
1250    // Acquire connection guard before WebSocket upgrade
1251    let connection_guard = match ConnectionGuard::try_acquire() {
1252        Ok(guard) => guard,
1253        Err(()) => {
1254            return (
1255                StatusCode::SERVICE_UNAVAILABLE,
1256                "Too many concurrent connections",
1257            )
1258                .into_response();
1259        }
1260    };
1261
1262    ws.on_upgrade(move |socket| {
1263        let socket = socket
1264            .map_ok(to_tungstenite_message)
1265            .err_into()
1266            .with(|message| async move { to_axum_message(message) });
1267        let connection = Connection::new(Box::pin(socket));
1268        async move {
1269            server
1270                .handle_connection(
1271                    connection,
1272                    socket_address,
1273                    principal,
1274                    version,
1275                    release_channel,
1276                    user_agent.map(|header| header.to_string()),
1277                    country_code_header.map(|header| header.to_string()),
1278                    system_id_header.map(|header| header.to_string()),
1279                    None,
1280                    Executor::Production,
1281                    Some(connection_guard),
1282                )
1283                .await;
1284        }
1285    })
1286}
1287
1288pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1289    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1290    let connections_metric = CONNECTIONS_METRIC
1291        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1292
1293    let connections = server
1294        .connection_pool
1295        .lock()
1296        .connections()
1297        .filter(|connection| !connection.admin)
1298        .count();
1299    connections_metric.set(connections as _);
1300
1301    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1302    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1303        register_int_gauge!(
1304            "shared_projects",
1305            "number of open projects with one or more guests"
1306        )
1307        .unwrap()
1308    });
1309
1310    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1311    shared_projects_metric.set(shared_projects as _);
1312
1313    let encoder = prometheus::TextEncoder::new();
1314    let metric_families = prometheus::gather();
1315    let encoded_metrics = encoder
1316        .encode_to_string(&metric_families)
1317        .map_err(|err| anyhow!("{err}"))?;
1318    Ok(encoded_metrics)
1319}
1320
1321#[instrument(err, skip(executor))]
1322async fn connection_lost(
1323    session: Session,
1324    mut teardown: watch::Receiver<bool>,
1325    executor: Executor,
1326) -> Result<()> {
1327    session.peer.disconnect(session.connection_id);
1328    session
1329        .connection_pool()
1330        .await
1331        .remove_connection(session.connection_id)?;
1332
1333    session
1334        .db()
1335        .await
1336        .connection_lost(session.connection_id)
1337        .await
1338        .trace_err();
1339
1340    futures::select_biased! {
1341        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1342
1343            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1344            leave_room_for_session(&session, session.connection_id).await.trace_err();
1345            leave_channel_buffers_for_session(&session)
1346                .await
1347                .trace_err();
1348
1349            if !session
1350                .connection_pool()
1351                .await
1352                .is_user_online(session.user_id())
1353            {
1354                let db = session.db().await;
1355                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1356                    room_updated(&room, &session.peer);
1357                }
1358            }
1359
1360            update_user_contacts(session.user_id(), &session).await?;
1361        },
1362        _ = teardown.changed().fuse() => {}
1363    }
1364
1365    Ok(())
1366}
1367
1368/// Acknowledges a ping from a client, used to keep the connection alive.
1369async fn ping(
1370    _: proto::Ping,
1371    response: Response<proto::Ping>,
1372    _session: MessageContext,
1373) -> Result<()> {
1374    response.send(proto::Ack {})?;
1375    Ok(())
1376}
1377
1378/// Creates a new room for calling (outside of channels)
1379async fn create_room(
1380    _request: proto::CreateRoom,
1381    response: Response<proto::CreateRoom>,
1382    session: MessageContext,
1383) -> Result<()> {
1384    let livekit_room = nanoid::nanoid!(30);
1385
1386    let live_kit_connection_info = util::maybe!(async {
1387        let live_kit = session.app_state.livekit_client.as_ref();
1388        let live_kit = live_kit?;
1389        let user_id = session.user_id().to_string();
1390
1391        let token = live_kit
1392            .room_token(&livekit_room, &user_id.to_string())
1393            .trace_err()?;
1394
1395        Some(proto::LiveKitConnectionInfo {
1396            server_url: live_kit.url().into(),
1397            token,
1398            can_publish: true,
1399        })
1400    })
1401    .await;
1402
1403    let room = session
1404        .db()
1405        .await
1406        .create_room(session.user_id(), session.connection_id, &livekit_room)
1407        .await?;
1408
1409    response.send(proto::CreateRoomResponse {
1410        room: Some(room.clone()),
1411        live_kit_connection_info,
1412    })?;
1413
1414    update_user_contacts(session.user_id(), &session).await?;
1415    Ok(())
1416}
1417
1418/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1419async fn join_room(
1420    request: proto::JoinRoom,
1421    response: Response<proto::JoinRoom>,
1422    session: MessageContext,
1423) -> Result<()> {
1424    let room_id = RoomId::from_proto(request.id);
1425
1426    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1427
1428    if let Some(channel_id) = channel_id {
1429        return join_channel_internal(channel_id, Box::new(response), session).await;
1430    }
1431
1432    let joined_room = {
1433        let room = session
1434            .db()
1435            .await
1436            .join_room(room_id, session.user_id(), session.connection_id)
1437            .await?;
1438        room_updated(&room.room, &session.peer);
1439        room.into_inner()
1440    };
1441
1442    for connection_id in session
1443        .connection_pool()
1444        .await
1445        .user_connection_ids(session.user_id())
1446    {
1447        session
1448            .peer
1449            .send(
1450                connection_id,
1451                proto::CallCanceled {
1452                    room_id: room_id.to_proto(),
1453                },
1454            )
1455            .trace_err();
1456    }
1457
1458    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1459    {
1460        live_kit
1461            .room_token(
1462                &joined_room.room.livekit_room,
1463                &session.user_id().to_string(),
1464            )
1465            .trace_err()
1466            .map(|token| proto::LiveKitConnectionInfo {
1467                server_url: live_kit.url().into(),
1468                token,
1469                can_publish: true,
1470            })
1471    } else {
1472        None
1473    };
1474
1475    response.send(proto::JoinRoomResponse {
1476        room: Some(joined_room.room),
1477        channel_id: None,
1478        live_kit_connection_info,
1479    })?;
1480
1481    update_user_contacts(session.user_id(), &session).await?;
1482    Ok(())
1483}
1484
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Restores the caller's membership through the database's `rejoin_room`, then:
/// 1. replies with the room, the re-shared projects (with their collaborators),
///    and the re-joined projects;
/// 2. broadcasts the updated room;
/// 3. for every project the caller re-shared, tells its collaborators that the
///    caller's peer id changed from `old_connection_id` to the current connection
///    id, and forwards an `UpdateProject` carrying the project's worktrees;
/// 4. replays state for re-joined projects via `notify_rejoined_projects`.
///
/// If the room is associated with a channel, `channel_updated` runs after the
/// rejoin result has been consumed. The caller's contacts are refreshed last.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    // Filled in from the rejoin result inside the block below, so they remain
    // usable after `rejoined_room` has been consumed by `into_inner`.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Answer the rejoining client before notifying anyone else.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Re-shared projects: the caller is the sharer, so its collaborators must
        // learn the new peer id and receive the current worktree list.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                // Best-effort: send failures are only logged.
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // The current connection is passed as the sender and as the
            // originator of each forwarded `UpdateProject`.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Re-joined projects: stream their state back to this connection.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Consume the guard and keep only the pieces needed below.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1576
/// Brings a reconnecting connection back up to date on the projects it rejoined.
///
/// Works in two passes over `rejoined_projects`:
/// 1. Every collaborator of each project is sent `UpdateProjectCollaborator`,
///    mapping the project's `old_connection_id` to the session's current
///    connection id. Send failures in this pass are only logged.
/// 2. For each project, the session's own connection is sent, in this order:
///    - the worktree entries (split with `proto::split_worktree_update`);
///    - the diagnostic summaries;
///    - the worktree settings files;
///    - the updated repositories (split with `split_repository_update`);
///    - the removed repository ids.
///
/// Each project's `worktrees`, `updated_repositories`, and `removed_repositories`
/// are moved out with `mem::take`, so they are empty after this call.
///
/// # Errors
///
/// Returns the first error from sending to the session's own connection.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    // Pass 1: tell other collaborators about this peer's new connection id.
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    // Pass 2: stream project state to the rejoining connection itself.
    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                // Flagged as the last update only when the latest scan has completed.
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics: the first summary goes in
            // `summary`, the rest in `more_summaries`; nothing is sent if empty.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            // One message per settings file in this worktree.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        // Project-level repository state, after all worktrees.
        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1661
1662/// leave room disconnects from the room.
1663async fn leave_room(
1664    _: proto::LeaveRoom,
1665    response: Response<proto::LeaveRoom>,
1666    session: MessageContext,
1667) -> Result<()> {
1668    leave_room_for_session(&session, session.connection_id).await?;
1669    response.send(proto::Ack {})?;
1670    Ok(())
1671}
1672
1673/// Updates the permissions of someone else in the room.
1674async fn set_room_participant_role(
1675    request: proto::SetRoomParticipantRole,
1676    response: Response<proto::SetRoomParticipantRole>,
1677    session: MessageContext,
1678) -> Result<()> {
1679    let user_id = UserId::from_proto(request.user_id);
1680    let role = ChannelRole::from(request.role());
1681
1682    let (livekit_room, can_publish) = {
1683        let room = session
1684            .db()
1685            .await
1686            .set_room_participant_role(
1687                session.user_id(),
1688                RoomId::from_proto(request.room_id),
1689                user_id,
1690                role,
1691            )
1692            .await?;
1693
1694        let livekit_room = room.livekit_room.clone();
1695        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1696        room_updated(&room, &session.peer);
1697        (livekit_room, can_publish)
1698    };
1699
1700    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1701        live_kit
1702            .update_participant(
1703                livekit_room.clone(),
1704                request.user_id.to_string(),
1705                livekit_api::proto::ParticipantPermission {
1706                    can_subscribe: true,
1707                    can_publish,
1708                    can_publish_data: can_publish,
1709                    hidden: false,
1710                    recorder: false,
1711                },
1712            )
1713            .await
1714            .trace_err();
1715    }
1716
1717    response.send(proto::Ack {})?;
1718    Ok(())
1719}
1720
1721/// Call someone else into the current room
1722async fn call(
1723    request: proto::Call,
1724    response: Response<proto::Call>,
1725    session: MessageContext,
1726) -> Result<()> {
1727    let room_id = RoomId::from_proto(request.room_id);
1728    let calling_user_id = session.user_id();
1729    let calling_connection_id = session.connection_id;
1730    let called_user_id = UserId::from_proto(request.called_user_id);
1731    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1732    if !session
1733        .db()
1734        .await
1735        .has_contact(calling_user_id, called_user_id)
1736        .await?
1737    {
1738        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1739    }
1740
1741    let incoming_call = {
1742        let (room, incoming_call) = &mut *session
1743            .db()
1744            .await
1745            .call(
1746                room_id,
1747                calling_user_id,
1748                calling_connection_id,
1749                called_user_id,
1750                initial_project_id,
1751            )
1752            .await?;
1753        room_updated(room, &session.peer);
1754        mem::take(incoming_call)
1755    };
1756    update_user_contacts(called_user_id, &session).await?;
1757
1758    let mut calls = session
1759        .connection_pool()
1760        .await
1761        .user_connection_ids(called_user_id)
1762        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1763        .collect::<FuturesUnordered<_>>();
1764
1765    while let Some(call_response) = calls.next().await {
1766        match call_response.as_ref() {
1767            Ok(_) => {
1768                response.send(proto::Ack {})?;
1769                return Ok(());
1770            }
1771            Err(_) => {
1772                call_response.trace_err();
1773            }
1774        }
1775    }
1776
1777    {
1778        let room = session
1779            .db()
1780            .await
1781            .call_failed(room_id, called_user_id)
1782            .await?;
1783        room_updated(&room, &session.peer);
1784    }
1785    update_user_contacts(called_user_id, &session).await?;
1786
1787    Err(anyhow!("failed to ring user"))?
1788}
1789
1790/// Cancel an outgoing call.
1791async fn cancel_call(
1792    request: proto::CancelCall,
1793    response: Response<proto::CancelCall>,
1794    session: MessageContext,
1795) -> Result<()> {
1796    let called_user_id = UserId::from_proto(request.called_user_id);
1797    let room_id = RoomId::from_proto(request.room_id);
1798    {
1799        let room = session
1800            .db()
1801            .await
1802            .cancel_call(room_id, session.connection_id, called_user_id)
1803            .await?;
1804        room_updated(&room, &session.peer);
1805    }
1806
1807    for connection_id in session
1808        .connection_pool()
1809        .await
1810        .user_connection_ids(called_user_id)
1811    {
1812        session
1813            .peer
1814            .send(
1815                connection_id,
1816                proto::CallCanceled {
1817                    room_id: room_id.to_proto(),
1818                },
1819            )
1820            .trace_err();
1821    }
1822    response.send(proto::Ack {})?;
1823
1824    update_user_contacts(called_user_id, &session).await?;
1825    Ok(())
1826}
1827
1828/// Decline an incoming call.
1829async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1830    let room_id = RoomId::from_proto(message.room_id);
1831    {
1832        let room = session
1833            .db()
1834            .await
1835            .decline_call(Some(room_id), session.user_id())
1836            .await?
1837            .context("declining call")?;
1838        room_updated(&room, &session.peer);
1839    }
1840
1841    for connection_id in session
1842        .connection_pool()
1843        .await
1844        .user_connection_ids(session.user_id())
1845    {
1846        session
1847            .peer
1848            .send(
1849                connection_id,
1850                proto::CallCanceled {
1851                    room_id: room_id.to_proto(),
1852                },
1853            )
1854            .trace_err();
1855    }
1856    update_user_contacts(session.user_id(), &session).await?;
1857    Ok(())
1858}
1859
1860/// Updates other participants in the room with your current location.
1861async fn update_participant_location(
1862    request: proto::UpdateParticipantLocation,
1863    response: Response<proto::UpdateParticipantLocation>,
1864    session: MessageContext,
1865) -> Result<()> {
1866    let room_id = RoomId::from_proto(request.room_id);
1867    let location = request.location.context("invalid location")?;
1868
1869    let db = session.db().await;
1870    let room = db
1871        .update_room_participant_location(room_id, session.connection_id, location)
1872        .await?;
1873
1874    room_updated(&room, &session.peer);
1875    response.send(proto::Ack {})?;
1876    Ok(())
1877}
1878
1879/// Share a project into the room.
1880async fn share_project(
1881    request: proto::ShareProject,
1882    response: Response<proto::ShareProject>,
1883    session: MessageContext,
1884) -> Result<()> {
1885    let (project_id, room) = &*session
1886        .db()
1887        .await
1888        .share_project(
1889            RoomId::from_proto(request.room_id),
1890            session.connection_id,
1891            &request.worktrees,
1892            request.is_ssh_project,
1893        )
1894        .await?;
1895    response.send(proto::ShareProjectResponse {
1896        project_id: project_id.to_proto(),
1897    })?;
1898    room_updated(room, &session.peer);
1899
1900    Ok(())
1901}
1902
1903/// Unshare a project from the room.
1904async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1905    let project_id = ProjectId::from_proto(message.project_id);
1906    unshare_project_internal(project_id, session.connection_id, &session).await
1907}
1908
1909async fn unshare_project_internal(
1910    project_id: ProjectId,
1911    connection_id: ConnectionId,
1912    session: &Session,
1913) -> Result<()> {
1914    let delete = {
1915        let room_guard = session
1916            .db()
1917            .await
1918            .unshare_project(project_id, connection_id)
1919            .await?;
1920
1921        let (delete, room, guest_connection_ids) = &*room_guard;
1922
1923        let message = proto::UnshareProject {
1924            project_id: project_id.to_proto(),
1925        };
1926
1927        broadcast(
1928            Some(connection_id),
1929            guest_connection_ids.iter().copied(),
1930            |conn_id| session.peer.send(conn_id, message.clone()),
1931        );
1932        if let Some(room) = room {
1933            room_updated(room, &session.peer);
1934        }
1935
1936        *delete
1937    };
1938
1939    if delete {
1940        let db = session.db().await;
1941        db.delete_project(project_id).await?;
1942    }
1943
1944    Ok(())
1945}
1946
1947/// Join someone elses shared project.
1948async fn join_project(
1949    request: proto::JoinProject,
1950    response: Response<proto::JoinProject>,
1951    session: MessageContext,
1952) -> Result<()> {
1953    let project_id = ProjectId::from_proto(request.project_id);
1954
1955    tracing::info!(%project_id, "join project");
1956
1957    let db = session.db().await;
1958    let (project, replica_id) = &mut *db
1959        .join_project(
1960            project_id,
1961            session.connection_id,
1962            session.user_id(),
1963            request.committer_name.clone(),
1964            request.committer_email.clone(),
1965        )
1966        .await?;
1967    drop(db);
1968    tracing::info!(%project_id, "join remote project");
1969    let collaborators = project
1970        .collaborators
1971        .iter()
1972        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1973        .map(|collaborator| collaborator.to_proto())
1974        .collect::<Vec<_>>();
1975    let project_id = project.id;
1976    let guest_user_id = session.user_id();
1977
1978    let worktrees = project
1979        .worktrees
1980        .iter()
1981        .map(|(id, worktree)| proto::WorktreeMetadata {
1982            id: *id,
1983            root_name: worktree.root_name.clone(),
1984            visible: worktree.visible,
1985            abs_path: worktree.abs_path.clone(),
1986        })
1987        .collect::<Vec<_>>();
1988
1989    let add_project_collaborator = proto::AddProjectCollaborator {
1990        project_id: project_id.to_proto(),
1991        collaborator: Some(proto::Collaborator {
1992            peer_id: Some(session.connection_id.into()),
1993            replica_id: replica_id.0 as u32,
1994            user_id: guest_user_id.to_proto(),
1995            is_host: false,
1996            committer_name: request.committer_name.clone(),
1997            committer_email: request.committer_email.clone(),
1998        }),
1999    };
2000
2001    for collaborator in &collaborators {
2002        session
2003            .peer
2004            .send(
2005                collaborator.peer_id.unwrap().into(),
2006                add_project_collaborator.clone(),
2007            )
2008            .trace_err();
2009    }
2010
2011    // First, we send the metadata associated with each worktree.
2012    let (language_servers, language_server_capabilities) = project
2013        .language_servers
2014        .clone()
2015        .into_iter()
2016        .map(|server| (server.server, server.capabilities))
2017        .unzip();
2018    response.send(proto::JoinProjectResponse {
2019        project_id: project.id.0 as u64,
2020        worktrees: worktrees.clone(),
2021        replica_id: replica_id.0 as u32,
2022        collaborators: collaborators.clone(),
2023        language_servers,
2024        language_server_capabilities,
2025        role: project.role.into(),
2026    })?;
2027
2028    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2029        // Stream this worktree's entries.
2030        let message = proto::UpdateWorktree {
2031            project_id: project_id.to_proto(),
2032            worktree_id,
2033            abs_path: worktree.abs_path.clone(),
2034            root_name: worktree.root_name,
2035            updated_entries: worktree.entries,
2036            removed_entries: Default::default(),
2037            scan_id: worktree.scan_id,
2038            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2039            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2040            removed_repositories: Default::default(),
2041        };
2042        for update in proto::split_worktree_update(message) {
2043            session.peer.send(session.connection_id, update.clone())?;
2044        }
2045
2046        // Stream this worktree's diagnostics.
2047        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
2048        if let Some(summary) = worktree_diagnostics.next() {
2049            let message = proto::UpdateDiagnosticSummary {
2050                project_id: project.id.to_proto(),
2051                worktree_id: worktree.id,
2052                summary: Some(summary),
2053                more_summaries: worktree_diagnostics.collect(),
2054            };
2055            session.peer.send(session.connection_id, message)?;
2056        }
2057
2058        for settings_file in worktree.settings_files {
2059            session.peer.send(
2060                session.connection_id,
2061                proto::UpdateWorktreeSettings {
2062                    project_id: project_id.to_proto(),
2063                    worktree_id: worktree.id,
2064                    path: settings_file.path,
2065                    content: Some(settings_file.content),
2066                    kind: Some(settings_file.kind.to_proto() as i32),
2067                },
2068            )?;
2069        }
2070    }
2071
2072    for repository in mem::take(&mut project.repositories) {
2073        for update in split_repository_update(repository) {
2074            session.peer.send(session.connection_id, update)?;
2075        }
2076    }
2077
2078    for language_server in &project.language_servers {
2079        session.peer.send(
2080            session.connection_id,
2081            proto::UpdateLanguageServer {
2082                project_id: project_id.to_proto(),
2083                server_name: Some(language_server.server.name.clone()),
2084                language_server_id: language_server.server.id,
2085                variant: Some(
2086                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2087                        proto::LspDiskBasedDiagnosticsUpdated {},
2088                    ),
2089                ),
2090            },
2091        )?;
2092    }
2093
2094    Ok(())
2095}
2096
2097/// Leave someone elses shared project.
2098async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2099    let sender_id = session.connection_id;
2100    let project_id = ProjectId::from_proto(request.project_id);
2101    let db = session.db().await;
2102
2103    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2104    tracing::info!(
2105        %project_id,
2106        "leave project"
2107    );
2108
2109    project_left(project, &session);
2110    if let Some(room) = room {
2111        room_updated(room, &session.peer);
2112    }
2113
2114    Ok(())
2115}
2116
2117/// Updates other participants with changes to the project
2118async fn update_project(
2119    request: proto::UpdateProject,
2120    response: Response<proto::UpdateProject>,
2121    session: MessageContext,
2122) -> Result<()> {
2123    let project_id = ProjectId::from_proto(request.project_id);
2124    let (room, guest_connection_ids) = &*session
2125        .db()
2126        .await
2127        .update_project(project_id, session.connection_id, &request.worktrees)
2128        .await?;
2129    broadcast(
2130        Some(session.connection_id),
2131        guest_connection_ids.iter().copied(),
2132        |connection_id| {
2133            session
2134                .peer
2135                .forward_send(session.connection_id, connection_id, request.clone())
2136        },
2137    );
2138    if let Some(room) = room {
2139        room_updated(room, &session.peer);
2140    }
2141    response.send(proto::Ack {})?;
2142
2143    Ok(())
2144}
2145
2146/// Updates other participants with changes to the worktree
2147async fn update_worktree(
2148    request: proto::UpdateWorktree,
2149    response: Response<proto::UpdateWorktree>,
2150    session: MessageContext,
2151) -> Result<()> {
2152    let guest_connection_ids = session
2153        .db()
2154        .await
2155        .update_worktree(&request, session.connection_id)
2156        .await?;
2157
2158    broadcast(
2159        Some(session.connection_id),
2160        guest_connection_ids.iter().copied(),
2161        |connection_id| {
2162            session
2163                .peer
2164                .forward_send(session.connection_id, connection_id, request.clone())
2165        },
2166    );
2167    response.send(proto::Ack {})?;
2168    Ok(())
2169}
2170
2171async fn update_repository(
2172    request: proto::UpdateRepository,
2173    response: Response<proto::UpdateRepository>,
2174    session: MessageContext,
2175) -> Result<()> {
2176    let guest_connection_ids = session
2177        .db()
2178        .await
2179        .update_repository(&request, session.connection_id)
2180        .await?;
2181
2182    broadcast(
2183        Some(session.connection_id),
2184        guest_connection_ids.iter().copied(),
2185        |connection_id| {
2186            session
2187                .peer
2188                .forward_send(session.connection_id, connection_id, request.clone())
2189        },
2190    );
2191    response.send(proto::Ack {})?;
2192    Ok(())
2193}
2194
2195async fn remove_repository(
2196    request: proto::RemoveRepository,
2197    response: Response<proto::RemoveRepository>,
2198    session: MessageContext,
2199) -> Result<()> {
2200    let guest_connection_ids = session
2201        .db()
2202        .await
2203        .remove_repository(&request, session.connection_id)
2204        .await?;
2205
2206    broadcast(
2207        Some(session.connection_id),
2208        guest_connection_ids.iter().copied(),
2209        |connection_id| {
2210            session
2211                .peer
2212                .forward_send(session.connection_id, connection_id, request.clone())
2213        },
2214    );
2215    response.send(proto::Ack {})?;
2216    Ok(())
2217}
2218
2219/// Updates other participants with changes to the diagnostics
2220async fn update_diagnostic_summary(
2221    message: proto::UpdateDiagnosticSummary,
2222    session: MessageContext,
2223) -> Result<()> {
2224    let guest_connection_ids = session
2225        .db()
2226        .await
2227        .update_diagnostic_summary(&message, session.connection_id)
2228        .await?;
2229
2230    broadcast(
2231        Some(session.connection_id),
2232        guest_connection_ids.iter().copied(),
2233        |connection_id| {
2234            session
2235                .peer
2236                .forward_send(session.connection_id, connection_id, message.clone())
2237        },
2238    );
2239
2240    Ok(())
2241}
2242
2243/// Updates other participants with changes to the worktree settings
2244async fn update_worktree_settings(
2245    message: proto::UpdateWorktreeSettings,
2246    session: MessageContext,
2247) -> Result<()> {
2248    let guest_connection_ids = session
2249        .db()
2250        .await
2251        .update_worktree_settings(&message, session.connection_id)
2252        .await?;
2253
2254    broadcast(
2255        Some(session.connection_id),
2256        guest_connection_ids.iter().copied(),
2257        |connection_id| {
2258            session
2259                .peer
2260                .forward_send(session.connection_id, connection_id, message.clone())
2261        },
2262    );
2263
2264    Ok(())
2265}
2266
2267/// Notify other participants that a language server has started.
2268async fn start_language_server(
2269    request: proto::StartLanguageServer,
2270    session: MessageContext,
2271) -> Result<()> {
2272    let guest_connection_ids = session
2273        .db()
2274        .await
2275        .start_language_server(&request, session.connection_id)
2276        .await?;
2277
2278    broadcast(
2279        Some(session.connection_id),
2280        guest_connection_ids.iter().copied(),
2281        |connection_id| {
2282            session
2283                .peer
2284                .forward_send(session.connection_id, connection_id, request.clone())
2285        },
2286    );
2287    Ok(())
2288}
2289
2290/// Notify other participants that a language server has changed.
2291async fn update_language_server(
2292    request: proto::UpdateLanguageServer,
2293    session: MessageContext,
2294) -> Result<()> {
2295    let project_id = ProjectId::from_proto(request.project_id);
2296    let db = session.db().await;
2297
2298    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2299    {
2300        if let Some(capabilities) = update.capabilities.clone() {
2301            db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2302                .await?;
2303        }
2304    }
2305
2306    let project_connection_ids = db
2307        .project_connection_ids(project_id, session.connection_id, true)
2308        .await?;
2309    broadcast(
2310        Some(session.connection_id),
2311        project_connection_ids.iter().copied(),
2312        |connection_id| {
2313            session
2314                .peer
2315                .forward_send(session.connection_id, connection_id, request.clone())
2316        },
2317    );
2318    Ok(())
2319}
2320
2321/// forward a project request to the host. These requests should be read only
2322/// as guests are allowed to send them.
2323async fn forward_read_only_project_request<T>(
2324    request: T,
2325    response: Response<T>,
2326    session: MessageContext,
2327) -> Result<()>
2328where
2329    T: EntityMessage + RequestMessage,
2330{
2331    let project_id = ProjectId::from_proto(request.remote_entity_id());
2332    let host_connection_id = session
2333        .db()
2334        .await
2335        .host_for_read_only_project_request(project_id, session.connection_id)
2336        .await?;
2337    let payload = session.forward_request(host_connection_id, request).await?;
2338    response.send(payload)?;
2339    Ok(())
2340}
2341
2342/// forward a project request to the host. These requests are disallowed
2343/// for guests.
2344async fn forward_mutating_project_request<T>(
2345    request: T,
2346    response: Response<T>,
2347    session: MessageContext,
2348) -> Result<()>
2349where
2350    T: EntityMessage + RequestMessage,
2351{
2352    let project_id = ProjectId::from_proto(request.remote_entity_id());
2353
2354    let host_connection_id = session
2355        .db()
2356        .await
2357        .host_for_mutating_project_request(project_id, session.connection_id)
2358        .await?;
2359    let payload = session.forward_request(host_connection_id, request).await?;
2360    response.send(payload)?;
2361    Ok(())
2362}
2363
2364async fn multi_lsp_query(
2365    request: MultiLspQuery,
2366    response: Response<MultiLspQuery>,
2367    session: MessageContext,
2368) -> Result<()> {
2369    tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2370    tracing::info!("multi_lsp_query message received");
2371    forward_mutating_project_request(request, response, session).await
2372}
2373
2374/// Notify other participants that a new buffer has been created
2375async fn create_buffer_for_peer(
2376    request: proto::CreateBufferForPeer,
2377    session: MessageContext,
2378) -> Result<()> {
2379    session
2380        .db()
2381        .await
2382        .check_user_is_project_host(
2383            ProjectId::from_proto(request.project_id),
2384            session.connection_id,
2385        )
2386        .await?;
2387    let peer_id = request.peer_id.context("invalid peer id")?;
2388    session
2389        .peer
2390        .forward_send(session.connection_id, peer_id.into(), request)?;
2391    Ok(())
2392}
2393
2394/// Notify other participants that a buffer has been updated. This is
2395/// allowed for guests as long as the update is limited to selections.
2396async fn update_buffer(
2397    request: proto::UpdateBuffer,
2398    response: Response<proto::UpdateBuffer>,
2399    session: MessageContext,
2400) -> Result<()> {
2401    let project_id = ProjectId::from_proto(request.project_id);
2402    let mut capability = Capability::ReadOnly;
2403
2404    for op in request.operations.iter() {
2405        match op.variant {
2406            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2407            Some(_) => capability = Capability::ReadWrite,
2408        }
2409    }
2410
2411    let host = {
2412        let guard = session
2413            .db()
2414            .await
2415            .connections_for_buffer_update(project_id, session.connection_id, capability)
2416            .await?;
2417
2418        let (host, guests) = &*guard;
2419
2420        broadcast(
2421            Some(session.connection_id),
2422            guests.clone(),
2423            |connection_id| {
2424                session
2425                    .peer
2426                    .forward_send(session.connection_id, connection_id, request.clone())
2427            },
2428        );
2429
2430        *host
2431    };
2432
2433    if host != session.connection_id {
2434        session.forward_request(host, request.clone()).await?;
2435    }
2436
2437    response.send(proto::Ack {})?;
2438    Ok(())
2439}
2440
/// Relay an `UpdateContext` message to the other connections of a project.
///
/// The capability the sender must hold is derived from the operation:
/// - read-only when the operation has no variant at all;
/// - read-only for a buffer operation that has no variant or only updates
///   selections;
/// - read-write for everything else, including a buffer operation whose inner
///   `operation` is missing.
///
/// Unlike `update_buffer`, the host receives the message through the same
/// `forward_send` fan-out as the guests rather than as a forwarded request.
///
/// Returns an error if `message.operation` is missing or if
/// `connections_for_buffer_update` fails (presumably when the sender lacks the
/// required capability — confirm).
async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let operation = message.operation.as_ref().context("invalid operation")?;
    let capability = match operation.variant.as_ref() {
        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
            if let Some(buffer_op) = buffer_op.operation.as_ref() {
                // Selection-only (or variant-less) buffer operations don't edit content.
                match buffer_op.variant {
                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
                        Capability::ReadOnly
                    }
                    _ => Capability::ReadWrite,
                }
            } else {
                Capability::ReadWrite
            }
        }
        Some(_) => Capability::ReadWrite,
        None => Capability::ReadOnly,
    };

    let guard = session
        .db()
        .await
        .connections_for_buffer_update(project_id, session.connection_id, capability)
        .await?;

    let (host, guests) = &*guard;

    // Fan out to every guest plus the host. The sender's own connection is
    // passed as the first argument (presumably so `broadcast` skips it).
    broadcast(
        Some(session.connection_id),
        guests.iter().chain([host]).copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2482
/// Notify other participants that a project has been updated.
///
/// Forwards `request` unchanged to every connection returned by
/// `project_connection_ids` for the message's project. The sender's own
/// connection is passed to `broadcast` as the first argument (presumably so it
/// is skipped).
///
/// NOTE(review): despite the name, nothing in this function checks that the
/// sender is the project host — confirm where that is enforced.
async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
    request: T,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.remote_entity_id());
    // NOTE(review): `update_followers` passes `true` for the trailing flag; its
    // meaning is not visible from here.
    let project_connection_ids = session
        .db()
        .await
        .project_connection_ids(project_id, session.connection_id, false)
        .await?;

    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2506
/// Start following another user in a call.
///
/// Runs `check_room_participants` for the leader and the follower in `room_id`
/// (presumably verifying that both are in the room — confirm). It then forwards
/// the request to the leader and relays the leader's reply.
///
/// When a `project_id` is given, the follow is also persisted and the resulting
/// room state is passed to `room_updated`.
///
/// Errors if `leader_id` is missing, the participant check fails, the leader's
/// request fails, or the database update fails.
async fn follow(
    request: proto::Follow,
    response: Response<proto::Follow>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    let response_payload = session.forward_request(leader_id, request).await?;
    response.send(response_payload)?;

    // The follower has already received the leader's reply at this point.
    // NOTE(review): if persisting the follow fails below, the error is returned
    // from the handler but the client has already seen success.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .follow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2538
/// Stop following another user in a call.
///
/// After `check_room_participants` succeeds, the message is forwarded to the
/// leader as a one-way message (no reply is awaited). For project-scoped
/// follows, the follow is then removed in the database and the resulting room
/// state is passed to `room_updated`.
///
/// Errors if `leader_id` is missing, the participant check fails, forwarding
/// fails, or the database update fails.
async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    session
        .peer
        .forward_send(session.connection_id, leader_id, request)?;

    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .unfollow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2567
2568/// Notify everyone following you of your current location.
2569async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2570    let room_id = RoomId::from_proto(request.room_id);
2571    let database = session.db.lock().await;
2572
2573    let connection_ids = if let Some(project_id) = request.project_id {
2574        let project_id = ProjectId::from_proto(project_id);
2575        database
2576            .project_connection_ids(project_id, session.connection_id, true)
2577            .await?
2578    } else {
2579        database
2580            .room_connection_ids(room_id, session.connection_id)
2581            .await?
2582    };
2583
2584    // For now, don't send view update messages back to that view's current leader.
2585    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2586        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2587        _ => None,
2588    });
2589
2590    for connection_id in connection_ids.iter().cloned() {
2591        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2592            session
2593                .peer
2594                .forward_send(session.connection_id, connection_id, request.clone())?;
2595        }
2596    }
2597    Ok(())
2598}
2599
2600/// Get public data about users.
2601async fn get_users(
2602    request: proto::GetUsers,
2603    response: Response<proto::GetUsers>,
2604    session: MessageContext,
2605) -> Result<()> {
2606    let user_ids = request
2607        .user_ids
2608        .into_iter()
2609        .map(UserId::from_proto)
2610        .collect();
2611    let users = session
2612        .db()
2613        .await
2614        .get_users_by_ids(user_ids)
2615        .await?
2616        .into_iter()
2617        .map(|user| proto::User {
2618            id: user.id.to_proto(),
2619            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2620            github_login: user.github_login,
2621            name: user.name,
2622        })
2623        .collect();
2624    response.send(proto::UsersResponse { users })?;
2625    Ok(())
2626}
2627
2628/// Search for users (to invite) buy Github login
2629async fn fuzzy_search_users(
2630    request: proto::FuzzySearchUsers,
2631    response: Response<proto::FuzzySearchUsers>,
2632    session: MessageContext,
2633) -> Result<()> {
2634    let query = request.query;
2635    let users = match query.len() {
2636        0 => vec![],
2637        1 | 2 => session
2638            .db()
2639            .await
2640            .get_user_by_github_login(&query)
2641            .await?
2642            .into_iter()
2643            .collect(),
2644        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2645    };
2646    let users = users
2647        .into_iter()
2648        .filter(|user| user.id != session.user_id())
2649        .map(|user| proto::User {
2650            id: user.id.to_proto(),
2651            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2652            github_login: user.github_login,
2653            name: user.name,
2654        })
2655        .collect();
2656    response.send(proto::UsersResponse { users })?;
2657    Ok(())
2658}
2659
2660/// Send a contact request to another user.
2661async fn request_contact(
2662    request: proto::RequestContact,
2663    response: Response<proto::RequestContact>,
2664    session: MessageContext,
2665) -> Result<()> {
2666    let requester_id = session.user_id();
2667    let responder_id = UserId::from_proto(request.responder_id);
2668    if requester_id == responder_id {
2669        return Err(anyhow!("cannot add yourself as a contact"))?;
2670    }
2671
2672    let notifications = session
2673        .db()
2674        .await
2675        .send_contact_request(requester_id, responder_id)
2676        .await?;
2677
2678    // Update outgoing contact requests of requester
2679    let mut update = proto::UpdateContacts::default();
2680    update.outgoing_requests.push(responder_id.to_proto());
2681    for connection_id in session
2682        .connection_pool()
2683        .await
2684        .user_connection_ids(requester_id)
2685    {
2686        session.peer.send(connection_id, update.clone())?;
2687    }
2688
2689    // Update incoming contact requests of responder
2690    let mut update = proto::UpdateContacts::default();
2691    update
2692        .incoming_requests
2693        .push(proto::IncomingContactRequest {
2694            requester_id: requester_id.to_proto(),
2695        });
2696    let connection_pool = session.connection_pool().await;
2697    for connection_id in connection_pool.user_connection_ids(responder_id) {
2698        session.peer.send(connection_id, update.clone())?;
2699    }
2700
2701    send_notifications(&connection_pool, &session.peer, notifications);
2702
2703    response.send(proto::Ack {})?;
2704    Ok(())
2705}
2706
/// Accept, decline, or dismiss a contact request.
///
/// A `Dismiss` response only calls `dismiss_contact_notification`.
///
/// Any other response resolves the request in the database; it counts as
/// accepted only when the value is exactly `Accept`. Both users then receive an
/// `UpdateContacts` that removes the pending request. On acceptance, that update
/// also adds the other user as a contact together with their busy status.
/// Notifications returned by the database are delivered last.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    // NOTE(review): this database guard stays alive for the rest of the function,
    // including while the connection pool is locked and messages are sent.
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        // Every value other than `Dismiss` and `Accept` (including unknown values)
        // is treated as a decline.
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2764
/// Remove a contact, or withdraw a pending contact request.
///
/// `contact_accepted` (from the database) decides what both sides are told:
/// - if the users were contacts, each receives a `remove_contacts` entry for the
///   other;
/// - otherwise the requester's outgoing request and the responder's incoming
///   request are removed.
///
/// If the database also deleted a notification, each responder connection gets
/// a `DeleteNotification` right after its contact update.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    // NOTE(review): this database guard stays alive for the rest of the function.
    let db = session.db().await;
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2815
2816fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2817    version.0.minor() < 139
2818}
2819
2820async fn subscribe_to_channels(
2821    _: proto::SubscribeToChannels,
2822    session: MessageContext,
2823) -> Result<()> {
2824    subscribe_user_to_channels(session.user_id(), &session).await?;
2825    Ok(())
2826}
2827
2828async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2829    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2830    let mut pool = session.connection_pool().await;
2831    for membership in &channels_for_user.channel_memberships {
2832        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2833    }
2834    session.peer.send(
2835        session.connection_id,
2836        build_update_user_channels(&channels_for_user),
2837    )?;
2838    session.peer.send(
2839        session.connection_id,
2840        build_channels_update(channels_for_user),
2841    )?;
2842    Ok(())
2843}
2844
2845/// Creates a new channel.
2846async fn create_channel(
2847    request: proto::CreateChannel,
2848    response: Response<proto::CreateChannel>,
2849    session: MessageContext,
2850) -> Result<()> {
2851    let db = session.db().await;
2852
2853    let parent_id = request.parent_id.map(ChannelId::from_proto);
2854    let (channel, membership) = db
2855        .create_channel(&request.name, parent_id, session.user_id())
2856        .await?;
2857
2858    let root_id = channel.root_id();
2859    let channel = Channel::from_model(channel);
2860
2861    response.send(proto::CreateChannelResponse {
2862        channel: Some(channel.to_proto()),
2863        parent_id: request.parent_id,
2864    })?;
2865
2866    let mut connection_pool = session.connection_pool().await;
2867    if let Some(membership) = membership {
2868        connection_pool.subscribe_to_channel(
2869            membership.user_id,
2870            membership.channel_id,
2871            membership.role,
2872        );
2873        let update = proto::UpdateUserChannels {
2874            channel_memberships: vec![proto::ChannelMembership {
2875                channel_id: membership.channel_id.to_proto(),
2876                role: membership.role.into(),
2877            }],
2878            ..Default::default()
2879        };
2880        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2881            session.peer.send(connection_id, update.clone())?;
2882        }
2883    }
2884
2885    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2886        if !role.can_see_channel(channel.visibility) {
2887            continue;
2888        }
2889
2890        let update = proto::UpdateChannels {
2891            channels: vec![channel.to_proto()],
2892            ..Default::default()
2893        };
2894        session.peer.send(connection_id, update.clone())?;
2895    }
2896
2897    Ok(())
2898}
2899
2900/// Delete a channel
2901async fn delete_channel(
2902    request: proto::DeleteChannel,
2903    response: Response<proto::DeleteChannel>,
2904    session: MessageContext,
2905) -> Result<()> {
2906    let db = session.db().await;
2907
2908    let channel_id = request.channel_id;
2909    let (root_channel, removed_channels) = db
2910        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2911        .await?;
2912    response.send(proto::Ack {})?;
2913
2914    // Notify members of removed channels
2915    let mut update = proto::UpdateChannels::default();
2916    update
2917        .delete_channels
2918        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2919
2920    let connection_pool = session.connection_pool().await;
2921    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2922        session.peer.send(connection_id, update.clone())?;
2923    }
2924
2925    Ok(())
2926}
2927
2928/// Invite someone to join a channel.
2929async fn invite_channel_member(
2930    request: proto::InviteChannelMember,
2931    response: Response<proto::InviteChannelMember>,
2932    session: MessageContext,
2933) -> Result<()> {
2934    let db = session.db().await;
2935    let channel_id = ChannelId::from_proto(request.channel_id);
2936    let invitee_id = UserId::from_proto(request.user_id);
2937    let InviteMemberResult {
2938        channel,
2939        notifications,
2940    } = db
2941        .invite_channel_member(
2942            channel_id,
2943            invitee_id,
2944            session.user_id(),
2945            request.role().into(),
2946        )
2947        .await?;
2948
2949    let update = proto::UpdateChannels {
2950        channel_invitations: vec![channel.to_proto()],
2951        ..Default::default()
2952    };
2953
2954    let connection_pool = session.connection_pool().await;
2955    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2956        session.peer.send(connection_id, update.clone())?;
2957    }
2958
2959    send_notifications(&connection_pool, &session.peer, notifications);
2960
2961    response.send(proto::Ack {})?;
2962    Ok(())
2963}
2964
2965/// remove someone from a channel
2966async fn remove_channel_member(
2967    request: proto::RemoveChannelMember,
2968    response: Response<proto::RemoveChannelMember>,
2969    session: MessageContext,
2970) -> Result<()> {
2971    let db = session.db().await;
2972    let channel_id = ChannelId::from_proto(request.channel_id);
2973    let member_id = UserId::from_proto(request.user_id);
2974
2975    let RemoveChannelMemberResult {
2976        membership_update,
2977        notification_id,
2978    } = db
2979        .remove_channel_member(channel_id, member_id, session.user_id())
2980        .await?;
2981
2982    let mut connection_pool = session.connection_pool().await;
2983    notify_membership_updated(
2984        &mut connection_pool,
2985        membership_update,
2986        member_id,
2987        &session.peer,
2988    );
2989    for connection_id in connection_pool.user_connection_ids(member_id) {
2990        if let Some(notification_id) = notification_id {
2991            session
2992                .peer
2993                .send(
2994                    connection_id,
2995                    proto::DeleteNotification {
2996                        notification_id: notification_id.to_proto(),
2997                    },
2998                )
2999                .trace_err();
3000        }
3001    }
3002
3003    response.send(proto::Ack {})?;
3004    Ok(())
3005}
3006
/// Toggle the channel between public and private.
///
/// Care is taken to maintain the invariant that public channels only descend
/// from public channels (though members-only channels can appear at any point
/// in the hierarchy). NOTE(review): that invariant is not enforced in this
/// function; presumably the database's `set_channel_visibility` handles it.
///
/// Afterwards, every user subscribed to the channel's root is handled according
/// to whether their root role can see the new visibility:
/// - if it can, they are subscribed to the channel and sent it;
/// - if it cannot, they are unsubscribed and told to delete it.
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect first: the loop body mutates `connection_pool`, which
    // `channel_user_ids` borrows.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3053
3054/// Alter the role for a user in the channel.
3055async fn set_channel_member_role(
3056    request: proto::SetChannelMemberRole,
3057    response: Response<proto::SetChannelMemberRole>,
3058    session: MessageContext,
3059) -> Result<()> {
3060    let db = session.db().await;
3061    let channel_id = ChannelId::from_proto(request.channel_id);
3062    let member_id = UserId::from_proto(request.user_id);
3063    let result = db
3064        .set_channel_member_role(
3065            channel_id,
3066            session.user_id(),
3067            member_id,
3068            request.role().into(),
3069        )
3070        .await?;
3071
3072    match result {
3073        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3074            let mut connection_pool = session.connection_pool().await;
3075            notify_membership_updated(
3076                &mut connection_pool,
3077                membership_update,
3078                member_id,
3079                &session.peer,
3080            )
3081        }
3082        db::SetMemberRoleResult::InviteUpdated(channel) => {
3083            let update = proto::UpdateChannels {
3084                channel_invitations: vec![channel.to_proto()],
3085                ..Default::default()
3086            };
3087
3088            for connection_id in session
3089                .connection_pool()
3090                .await
3091                .user_connection_ids(member_id)
3092            {
3093                session.peer.send(connection_id, update.clone())?;
3094            }
3095        }
3096    }
3097
3098    response.send(proto::Ack {})?;
3099    Ok(())
3100}
3101
3102/// Change the name of a channel
3103async fn rename_channel(
3104    request: proto::RenameChannel,
3105    response: Response<proto::RenameChannel>,
3106    session: MessageContext,
3107) -> Result<()> {
3108    let db = session.db().await;
3109    let channel_id = ChannelId::from_proto(request.channel_id);
3110    let channel_model = db
3111        .rename_channel(channel_id, session.user_id(), &request.name)
3112        .await?;
3113    let root_id = channel_model.root_id();
3114    let channel = Channel::from_model(channel_model);
3115
3116    response.send(proto::RenameChannelResponse {
3117        channel: Some(channel.to_proto()),
3118    })?;
3119
3120    let connection_pool = session.connection_pool().await;
3121    let update = proto::UpdateChannels {
3122        channels: vec![channel.to_proto()],
3123        ..Default::default()
3124    };
3125    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3126        if role.can_see_channel(channel.visibility) {
3127            session.peer.send(connection_id, update.clone())?;
3128        }
3129    }
3130
3131    Ok(())
3132}
3133
3134/// Move a channel to a new parent.
3135async fn move_channel(
3136    request: proto::MoveChannel,
3137    response: Response<proto::MoveChannel>,
3138    session: MessageContext,
3139) -> Result<()> {
3140    let channel_id = ChannelId::from_proto(request.channel_id);
3141    let to = ChannelId::from_proto(request.to);
3142
3143    let (root_id, channels) = session
3144        .db()
3145        .await
3146        .move_channel(channel_id, to, session.user_id())
3147        .await?;
3148
3149    let connection_pool = session.connection_pool().await;
3150    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3151        let channels = channels
3152            .iter()
3153            .filter_map(|channel| {
3154                if role.can_see_channel(channel.visibility) {
3155                    Some(channel.to_proto())
3156                } else {
3157                    None
3158                }
3159            })
3160            .collect::<Vec<_>>();
3161        if channels.is_empty() {
3162            continue;
3163        }
3164
3165        let update = proto::UpdateChannels {
3166            channels,
3167            ..Default::default()
3168        };
3169
3170        session.peer.send(connection_id, update.clone())?;
3171    }
3172
3173    response.send(Ack {})?;
3174    Ok(())
3175}
3176
3177async fn reorder_channel(
3178    request: proto::ReorderChannel,
3179    response: Response<proto::ReorderChannel>,
3180    session: MessageContext,
3181) -> Result<()> {
3182    let channel_id = ChannelId::from_proto(request.channel_id);
3183    let direction = request.direction();
3184
3185    let updated_channels = session
3186        .db()
3187        .await
3188        .reorder_channel(channel_id, direction, session.user_id())
3189        .await?;
3190
3191    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3192        let connection_pool = session.connection_pool().await;
3193        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3194            let channels = updated_channels
3195                .iter()
3196                .filter_map(|channel| {
3197                    if role.can_see_channel(channel.visibility) {
3198                        Some(channel.to_proto())
3199                    } else {
3200                        None
3201                    }
3202                })
3203                .collect::<Vec<_>>();
3204
3205            if channels.is_empty() {
3206                continue;
3207            }
3208
3209            let update = proto::UpdateChannels {
3210                channels,
3211                ..Default::default()
3212            };
3213
3214            session.peer.send(connection_id, update.clone())?;
3215        }
3216    }
3217
3218    response.send(Ack {})?;
3219    Ok(())
3220}
3221
3222/// Get the list of channel members
3223async fn get_channel_members(
3224    request: proto::GetChannelMembers,
3225    response: Response<proto::GetChannelMembers>,
3226    session: MessageContext,
3227) -> Result<()> {
3228    let db = session.db().await;
3229    let channel_id = ChannelId::from_proto(request.channel_id);
3230    let limit = if request.limit == 0 {
3231        u16::MAX as u64
3232    } else {
3233        request.limit
3234    };
3235    let (members, users) = db
3236        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3237        .await?;
3238    response.send(proto::GetChannelMembersResponse { members, users })?;
3239    Ok(())
3240}
3241
/// Accept or decline a channel invitation.
///
/// What is sent depends on the database result:
/// - a membership change (presumably on acceptance): handled by
///   `notify_membership_updated`;
/// - no membership change: the invitation is removed from all of the user's
///   connections.
///
/// Notifications returned by the database are delivered in either case.
async fn respond_to_channel_invite(
    request: proto::RespondToChannelInvite,
    response: Response<proto::RespondToChannelInvite>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let RespondToChannelInvite {
        membership_update,
        notifications,
    } = db
        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
        .await?;

    let mut connection_pool = session.connection_pool().await;
    if let Some(membership_update) = membership_update {
        notify_membership_updated(
            &mut connection_pool,
            membership_update,
            session.user_id(),
            &session.peer,
        );
    } else {
        // No membership was created, so just clear the pending invitation.
        let update = proto::UpdateChannels {
            remove_channel_invitations: vec![channel_id.to_proto()],
            ..Default::default()
        };

        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
            session.peer.send(connection_id, update.clone())?;
        }
    };

    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;

    Ok(())
}
3282
3283/// Join the channels' room
3284async fn join_channel(
3285    request: proto::JoinChannel,
3286    response: Response<proto::JoinChannel>,
3287    session: MessageContext,
3288) -> Result<()> {
3289    let channel_id = ChannelId::from_proto(request.channel_id);
3290    join_channel_internal(channel_id, Box::new(response), session).await
3291}
3292
/// Abstraction over the response handles that can answer a channel join.
///
/// Both `JoinChannel` and `JoinRoom` requests reply with a `proto::JoinRoomResponse`, so
/// `join_channel_internal` is written against this trait rather than a concrete
/// `Response<_>` type.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the originating request, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Each impl forwards through the fully-qualified `Response::<_>::send` path, presumably the
// handle's own (inherent) `send`, so the call is not confused with this trait method.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3306
/// Shared implementation for joining a channel's room, used by `join_channel` and by any
/// request whose response implements `JoinChannelInternalResponse`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that room first.
/// 2. Join the channel's room in the database. This yields the room, an optional membership
///    update, and the user's role.
/// 3. If a LiveKit client is configured, mint a token. `ChannelRole::Guest` gets a guest
///    token with `can_publish: false`; every other role gets a room token with
///    `can_publish: true`. Token errors go through `trace_err()` and produce no connection
///    info instead of failing the request.
/// 4. Reply to the caller, apply the membership update (if any) to the connection pool, and
///    broadcast the room.
/// 5. Broadcast the channel's participant list and refresh the user's contacts.
///
/// # Errors
/// Returns an error if a database call fails, if the reply cannot be sent, or if the joined
/// room has no associated channel. That last check happens after the reply has been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database handle before `leave_room_for_session`, which acquires
            // its own via `session.db()`; it is re-acquired right after.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or when minting the token failed.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests may join the call but not publish media.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        // This pool guard lives until the end of the block, so it is released before
        // `channel_updated` acquires the pool again below.
        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The caller has already been answered at this point; a missing channel still surfaces
    // as an error from this handler.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3400
3401/// Start editing the channel notes
3402async fn join_channel_buffer(
3403    request: proto::JoinChannelBuffer,
3404    response: Response<proto::JoinChannelBuffer>,
3405    session: MessageContext,
3406) -> Result<()> {
3407    let db = session.db().await;
3408    let channel_id = ChannelId::from_proto(request.channel_id);
3409
3410    let open_response = db
3411        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3412        .await?;
3413
3414    let collaborators = open_response.collaborators.clone();
3415    response.send(open_response)?;
3416
3417    let update = UpdateChannelBufferCollaborators {
3418        channel_id: channel_id.to_proto(),
3419        collaborators: collaborators.clone(),
3420    };
3421    channel_buffer_updated(
3422        session.connection_id,
3423        collaborators
3424            .iter()
3425            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3426        &update,
3427        &session.peer,
3428    );
3429
3430    Ok(())
3431}
3432
3433/// Edit the channel notes
3434async fn update_channel_buffer(
3435    request: proto::UpdateChannelBuffer,
3436    session: MessageContext,
3437) -> Result<()> {
3438    let db = session.db().await;
3439    let channel_id = ChannelId::from_proto(request.channel_id);
3440
3441    let (collaborators, epoch, version) = db
3442        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3443        .await?;
3444
3445    channel_buffer_updated(
3446        session.connection_id,
3447        collaborators.clone(),
3448        &proto::UpdateChannelBuffer {
3449            channel_id: channel_id.to_proto(),
3450            operations: request.operations,
3451        },
3452        &session.peer,
3453    );
3454
3455    let pool = &*session.connection_pool().await;
3456
3457    let non_collaborators =
3458        pool.channel_connection_ids(channel_id)
3459            .filter_map(|(connection_id, _)| {
3460                if collaborators.contains(&connection_id) {
3461                    None
3462                } else {
3463                    Some(connection_id)
3464                }
3465            });
3466
3467    broadcast(None, non_collaborators, |peer_id| {
3468        session.peer.send(
3469            peer_id,
3470            proto::UpdateChannels {
3471                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3472                    channel_id: channel_id.to_proto(),
3473                    epoch: epoch as u64,
3474                    version: version.clone(),
3475                }],
3476                ..Default::default()
3477            },
3478        )
3479    });
3480
3481    Ok(())
3482}
3483
3484/// Rejoin the channel notes after a connection blip
3485async fn rejoin_channel_buffers(
3486    request: proto::RejoinChannelBuffers,
3487    response: Response<proto::RejoinChannelBuffers>,
3488    session: MessageContext,
3489) -> Result<()> {
3490    let db = session.db().await;
3491    let buffers = db
3492        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3493        .await?;
3494
3495    for rejoined_buffer in &buffers {
3496        let collaborators_to_notify = rejoined_buffer
3497            .buffer
3498            .collaborators
3499            .iter()
3500            .filter_map(|c| Some(c.peer_id?.into()));
3501        channel_buffer_updated(
3502            session.connection_id,
3503            collaborators_to_notify,
3504            &proto::UpdateChannelBufferCollaborators {
3505                channel_id: rejoined_buffer.buffer.channel_id,
3506                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3507            },
3508            &session.peer,
3509        );
3510    }
3511
3512    response.send(proto::RejoinChannelBuffersResponse {
3513        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3514    })?;
3515
3516    Ok(())
3517}
3518
3519/// Stop editing the channel notes
3520async fn leave_channel_buffer(
3521    request: proto::LeaveChannelBuffer,
3522    response: Response<proto::LeaveChannelBuffer>,
3523    session: MessageContext,
3524) -> Result<()> {
3525    let db = session.db().await;
3526    let channel_id = ChannelId::from_proto(request.channel_id);
3527
3528    let left_buffer = db
3529        .leave_channel_buffer(channel_id, session.connection_id)
3530        .await?;
3531
3532    response.send(Ack {})?;
3533
3534    channel_buffer_updated(
3535        session.connection_id,
3536        left_buffer.connections,
3537        &proto::UpdateChannelBufferCollaborators {
3538            channel_id: channel_id.to_proto(),
3539            collaborators: left_buffer.collaborators,
3540        },
3541        &session.peer,
3542    );
3543
3544    Ok(())
3545}
3546
3547fn channel_buffer_updated<T: EnvelopedMessage>(
3548    sender_id: ConnectionId,
3549    collaborators: impl IntoIterator<Item = ConnectionId>,
3550    message: &T,
3551    peer: &Peer,
3552) {
3553    broadcast(Some(sender_id), collaborators, |peer_id| {
3554        peer.send(peer_id, message.clone())
3555    });
3556}
3557
3558fn send_notifications(
3559    connection_pool: &ConnectionPool,
3560    peer: &Peer,
3561    notifications: db::NotificationBatch,
3562) {
3563    for (user_id, notification) in notifications {
3564        for connection_id in connection_pool.user_connection_ids(user_id) {
3565            if let Err(error) = peer.send(
3566                connection_id,
3567                proto::AddNotification {
3568                    notification: Some(notification.clone()),
3569                },
3570            ) {
3571                tracing::error!(
3572                    "failed to send notification to {:?} {}",
3573                    connection_id,
3574                    error
3575                );
3576            }
3577        }
3578    }
3579}
3580
3581/// Send a message to the channel
3582async fn send_channel_message(
3583    request: proto::SendChannelMessage,
3584    response: Response<proto::SendChannelMessage>,
3585    session: MessageContext,
3586) -> Result<()> {
3587    // Validate the message body.
3588    let body = request.body.trim().to_string();
3589    if body.len() > MAX_MESSAGE_LEN {
3590        return Err(anyhow!("message is too long"))?;
3591    }
3592    if body.is_empty() {
3593        return Err(anyhow!("message can't be blank"))?;
3594    }
3595
3596    // TODO: adjust mentions if body is trimmed
3597
3598    let timestamp = OffsetDateTime::now_utc();
3599    let nonce = request.nonce.context("nonce can't be blank")?;
3600
3601    let channel_id = ChannelId::from_proto(request.channel_id);
3602    let CreatedChannelMessage {
3603        message_id,
3604        participant_connection_ids,
3605        notifications,
3606    } = session
3607        .db()
3608        .await
3609        .create_channel_message(
3610            channel_id,
3611            session.user_id(),
3612            &body,
3613            &request.mentions,
3614            timestamp,
3615            nonce.clone().into(),
3616            request.reply_to_message_id.map(MessageId::from_proto),
3617        )
3618        .await?;
3619
3620    let message = proto::ChannelMessage {
3621        sender_id: session.user_id().to_proto(),
3622        id: message_id.to_proto(),
3623        body,
3624        mentions: request.mentions,
3625        timestamp: timestamp.unix_timestamp() as u64,
3626        nonce: Some(nonce),
3627        reply_to_message_id: request.reply_to_message_id,
3628        edited_at: None,
3629    };
3630    broadcast(
3631        Some(session.connection_id),
3632        participant_connection_ids.clone(),
3633        |connection| {
3634            session.peer.send(
3635                connection,
3636                proto::ChannelMessageSent {
3637                    channel_id: channel_id.to_proto(),
3638                    message: Some(message.clone()),
3639                },
3640            )
3641        },
3642    );
3643    response.send(proto::SendChannelMessageResponse {
3644        message: Some(message),
3645    })?;
3646
3647    let pool = &*session.connection_pool().await;
3648    let non_participants =
3649        pool.channel_connection_ids(channel_id)
3650            .filter_map(|(connection_id, _)| {
3651                if participant_connection_ids.contains(&connection_id) {
3652                    None
3653                } else {
3654                    Some(connection_id)
3655                }
3656            });
3657    broadcast(None, non_participants, |peer_id| {
3658        session.peer.send(
3659            peer_id,
3660            proto::UpdateChannels {
3661                latest_channel_message_ids: vec![proto::ChannelMessageId {
3662                    channel_id: channel_id.to_proto(),
3663                    message_id: message_id.to_proto(),
3664                }],
3665                ..Default::default()
3666            },
3667        )
3668    });
3669    send_notifications(pool, &session.peer, notifications);
3670
3671    Ok(())
3672}
3673
3674/// Delete a channel message
3675async fn remove_channel_message(
3676    request: proto::RemoveChannelMessage,
3677    response: Response<proto::RemoveChannelMessage>,
3678    session: MessageContext,
3679) -> Result<()> {
3680    let channel_id = ChannelId::from_proto(request.channel_id);
3681    let message_id = MessageId::from_proto(request.message_id);
3682    let (connection_ids, existing_notification_ids) = session
3683        .db()
3684        .await
3685        .remove_channel_message(channel_id, message_id, session.user_id())
3686        .await?;
3687
3688    broadcast(
3689        Some(session.connection_id),
3690        connection_ids,
3691        move |connection| {
3692            session.peer.send(connection, request.clone())?;
3693
3694            for notification_id in &existing_notification_ids {
3695                session.peer.send(
3696                    connection,
3697                    proto::DeleteNotification {
3698                        notification_id: (*notification_id).to_proto(),
3699                    },
3700                )?;
3701            }
3702
3703            Ok(())
3704        },
3705    );
3706    response.send(proto::Ack {})?;
3707    Ok(())
3708}
3709
3710async fn update_channel_message(
3711    request: proto::UpdateChannelMessage,
3712    response: Response<proto::UpdateChannelMessage>,
3713    session: MessageContext,
3714) -> Result<()> {
3715    let channel_id = ChannelId::from_proto(request.channel_id);
3716    let message_id = MessageId::from_proto(request.message_id);
3717    let updated_at = OffsetDateTime::now_utc();
3718    let UpdatedChannelMessage {
3719        message_id,
3720        participant_connection_ids,
3721        notifications,
3722        reply_to_message_id,
3723        timestamp,
3724        deleted_mention_notification_ids,
3725        updated_mention_notifications,
3726    } = session
3727        .db()
3728        .await
3729        .update_channel_message(
3730            channel_id,
3731            message_id,
3732            session.user_id(),
3733            request.body.as_str(),
3734            &request.mentions,
3735            updated_at,
3736        )
3737        .await?;
3738
3739    let nonce = request.nonce.clone().context("nonce can't be blank")?;
3740
3741    let message = proto::ChannelMessage {
3742        sender_id: session.user_id().to_proto(),
3743        id: message_id.to_proto(),
3744        body: request.body.clone(),
3745        mentions: request.mentions.clone(),
3746        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3747        nonce: Some(nonce),
3748        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3749        edited_at: Some(updated_at.unix_timestamp() as u64),
3750    };
3751
3752    response.send(proto::Ack {})?;
3753
3754    let pool = &*session.connection_pool().await;
3755    broadcast(
3756        Some(session.connection_id),
3757        participant_connection_ids,
3758        |connection| {
3759            session.peer.send(
3760                connection,
3761                proto::ChannelMessageUpdate {
3762                    channel_id: channel_id.to_proto(),
3763                    message: Some(message.clone()),
3764                },
3765            )?;
3766
3767            for notification_id in &deleted_mention_notification_ids {
3768                session.peer.send(
3769                    connection,
3770                    proto::DeleteNotification {
3771                        notification_id: (*notification_id).to_proto(),
3772                    },
3773                )?;
3774            }
3775
3776            for notification in &updated_mention_notifications {
3777                session.peer.send(
3778                    connection,
3779                    proto::UpdateNotification {
3780                        notification: Some(notification.clone()),
3781                    },
3782                )?;
3783            }
3784
3785            Ok(())
3786        },
3787    );
3788
3789    send_notifications(pool, &session.peer, notifications);
3790
3791    Ok(())
3792}
3793
3794/// Mark a channel message as read
3795async fn acknowledge_channel_message(
3796    request: proto::AckChannelMessage,
3797    session: MessageContext,
3798) -> Result<()> {
3799    let channel_id = ChannelId::from_proto(request.channel_id);
3800    let message_id = MessageId::from_proto(request.message_id);
3801    let notifications = session
3802        .db()
3803        .await
3804        .observe_channel_message(channel_id, session.user_id(), message_id)
3805        .await?;
3806    send_notifications(
3807        &*session.connection_pool().await,
3808        &session.peer,
3809        notifications,
3810    );
3811    Ok(())
3812}
3813
3814/// Mark a buffer version as synced
3815async fn acknowledge_buffer_version(
3816    request: proto::AckBufferOperation,
3817    session: MessageContext,
3818) -> Result<()> {
3819    let buffer_id = BufferId::from_proto(request.buffer_id);
3820    session
3821        .db()
3822        .await
3823        .observe_buffer_version(
3824            buffer_id,
3825            session.user_id(),
3826            request.epoch as i32,
3827            &request.version,
3828        )
3829        .await?;
3830    Ok(())
3831}
3832
3833/// Get a Supermaven API key for the user
3834async fn get_supermaven_api_key(
3835    _request: proto::GetSupermavenApiKey,
3836    response: Response<proto::GetSupermavenApiKey>,
3837    session: MessageContext,
3838) -> Result<()> {
3839    let user_id: String = session.user_id().to_string();
3840    if !session.is_staff() {
3841        return Err(anyhow!("supermaven not enabled for this account"))?;
3842    }
3843
3844    let email = session.email().context("user must have an email")?;
3845
3846    let supermaven_admin_api = session
3847        .supermaven_client
3848        .as_ref()
3849        .context("supermaven not configured")?;
3850
3851    let result = supermaven_admin_api
3852        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3853        .await?;
3854
3855    response.send(proto::GetSupermavenApiKeyResponse {
3856        api_key: result.api_key,
3857    })?;
3858
3859    Ok(())
3860}
3861
3862/// Start receiving chat updates for a channel
3863async fn join_channel_chat(
3864    request: proto::JoinChannelChat,
3865    response: Response<proto::JoinChannelChat>,
3866    session: MessageContext,
3867) -> Result<()> {
3868    let channel_id = ChannelId::from_proto(request.channel_id);
3869
3870    let db = session.db().await;
3871    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3872        .await?;
3873    let messages = db
3874        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3875        .await?;
3876    response.send(proto::JoinChannelChatResponse {
3877        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3878        messages,
3879    })?;
3880    Ok(())
3881}
3882
3883/// Stop receiving chat updates for a channel
3884async fn leave_channel_chat(
3885    request: proto::LeaveChannelChat,
3886    session: MessageContext,
3887) -> Result<()> {
3888    let channel_id = ChannelId::from_proto(request.channel_id);
3889    session
3890        .db()
3891        .await
3892        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3893        .await?;
3894    Ok(())
3895}
3896
3897/// Retrieve the chat history for a channel
3898async fn get_channel_messages(
3899    request: proto::GetChannelMessages,
3900    response: Response<proto::GetChannelMessages>,
3901    session: MessageContext,
3902) -> Result<()> {
3903    let channel_id = ChannelId::from_proto(request.channel_id);
3904    let messages = session
3905        .db()
3906        .await
3907        .get_channel_messages(
3908            channel_id,
3909            session.user_id(),
3910            MESSAGE_COUNT_PER_PAGE,
3911            Some(MessageId::from_proto(request.before_message_id)),
3912        )
3913        .await?;
3914    response.send(proto::GetChannelMessagesResponse {
3915        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3916        messages,
3917    })?;
3918    Ok(())
3919}
3920
3921/// Retrieve specific chat messages
3922async fn get_channel_messages_by_id(
3923    request: proto::GetChannelMessagesById,
3924    response: Response<proto::GetChannelMessagesById>,
3925    session: MessageContext,
3926) -> Result<()> {
3927    let message_ids = request
3928        .message_ids
3929        .iter()
3930        .map(|id| MessageId::from_proto(*id))
3931        .collect::<Vec<_>>();
3932    let messages = session
3933        .db()
3934        .await
3935        .get_channel_messages_by_id(session.user_id(), &message_ids)
3936        .await?;
3937    response.send(proto::GetChannelMessagesResponse {
3938        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3939        messages,
3940    })?;
3941    Ok(())
3942}
3943
3944/// Retrieve the current users notifications
3945async fn get_notifications(
3946    request: proto::GetNotifications,
3947    response: Response<proto::GetNotifications>,
3948    session: MessageContext,
3949) -> Result<()> {
3950    let notifications = session
3951        .db()
3952        .await
3953        .get_notifications(
3954            session.user_id(),
3955            NOTIFICATION_COUNT_PER_PAGE,
3956            request.before_id.map(db::NotificationId::from_proto),
3957        )
3958        .await?;
3959    response.send(proto::GetNotificationsResponse {
3960        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3961        notifications,
3962    })?;
3963    Ok(())
3964}
3965
3966/// Mark notifications as read
3967async fn mark_notification_as_read(
3968    request: proto::MarkNotificationRead,
3969    response: Response<proto::MarkNotificationRead>,
3970    session: MessageContext,
3971) -> Result<()> {
3972    let database = &session.db().await;
3973    let notifications = database
3974        .mark_notification_as_read_by_id(
3975            session.user_id(),
3976            NotificationId::from_proto(request.notification_id),
3977        )
3978        .await?;
3979    send_notifications(
3980        &*session.connection_pool().await,
3981        &session.peer,
3982        notifications,
3983    );
3984    response.send(proto::Ack {})?;
3985    Ok(())
3986}
3987
3988/// Accept the terms of service (tos) on behalf of the current user
3989async fn accept_terms_of_service(
3990    _request: proto::AcceptTermsOfService,
3991    response: Response<proto::AcceptTermsOfService>,
3992    session: MessageContext,
3993) -> Result<()> {
3994    let db = session.db().await;
3995
3996    let accepted_tos_at = Utc::now();
3997    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
3998        .await?;
3999
4000    response.send(proto::AcceptTermsOfServiceResponse {
4001        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4002    })?;
4003
4004    Ok(())
4005}
4006
4007fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4008    let message = match message {
4009        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4010        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4011        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4012        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4013        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4014            code: frame.code.into(),
4015            reason: frame.reason.as_str().to_owned().into(),
4016        })),
4017        // We should never receive a frame while reading the message, according
4018        // to the `tungstenite` maintainers:
4019        //
4020        // > It cannot occur when you read messages from the WebSocket, but it
4021        // > can be used when you want to send the raw frames (e.g. you want to
4022        // > send the frames to the WebSocket without composing the full message first).
4023        // >
4024        // > — https://github.com/snapview/tungstenite-rs/issues/268
4025        TungsteniteMessage::Frame(_) => {
4026            bail!("received an unexpected frame while reading the message")
4027        }
4028    };
4029
4030    Ok(message)
4031}
4032
4033fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4034    match message {
4035        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4036        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4037        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4038        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4039        AxumMessage::Close(frame) => {
4040            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4041                code: frame.code.into(),
4042                reason: frame.reason.as_ref().into(),
4043            }))
4044        }
4045    }
4046}
4047
4048fn notify_membership_updated(
4049    connection_pool: &mut ConnectionPool,
4050    result: MembershipUpdated,
4051    user_id: UserId,
4052    peer: &Peer,
4053) {
4054    for membership in &result.new_channels.channel_memberships {
4055        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4056    }
4057    for channel_id in &result.removed_channels {
4058        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4059    }
4060
4061    let user_channels_update = proto::UpdateUserChannels {
4062        channel_memberships: result
4063            .new_channels
4064            .channel_memberships
4065            .iter()
4066            .map(|cm| proto::ChannelMembership {
4067                channel_id: cm.channel_id.to_proto(),
4068                role: cm.role.into(),
4069            })
4070            .collect(),
4071        ..Default::default()
4072    };
4073
4074    let mut update = build_channels_update(result.new_channels);
4075    update.delete_channels = result
4076        .removed_channels
4077        .into_iter()
4078        .map(|id| id.to_proto())
4079        .collect();
4080    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4081
4082    for connection_id in connection_pool.user_connection_ids(user_id) {
4083        peer.send(connection_id, user_channels_update.clone())
4084            .trace_err();
4085        peer.send(connection_id, update.clone()).trace_err();
4086    }
4087}
4088
4089fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4090    proto::UpdateUserChannels {
4091        channel_memberships: channels
4092            .channel_memberships
4093            .iter()
4094            .map(|m| proto::ChannelMembership {
4095                channel_id: m.channel_id.to_proto(),
4096                role: m.role.into(),
4097            })
4098            .collect(),
4099        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4100        observed_channel_message_id: channels.observed_channel_messages.clone(),
4101    }
4102}
4103
4104fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4105    let mut update = proto::UpdateChannels::default();
4106
4107    for channel in channels.channels {
4108        update.channels.push(channel.to_proto());
4109    }
4110
4111    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4112    update.latest_channel_message_ids = channels.latest_channel_messages;
4113
4114    for (channel_id, participants) in channels.channel_participants {
4115        update
4116            .channel_participants
4117            .push(proto::ChannelParticipants {
4118                channel_id: channel_id.to_proto(),
4119                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4120            });
4121    }
4122
4123    for channel in channels.invited_channels {
4124        update.channel_invitations.push(channel.to_proto());
4125    }
4126
4127    update
4128}
4129
4130fn build_initial_contacts_update(
4131    contacts: Vec<db::Contact>,
4132    pool: &ConnectionPool,
4133) -> proto::UpdateContacts {
4134    let mut update = proto::UpdateContacts::default();
4135
4136    for contact in contacts {
4137        match contact {
4138            db::Contact::Accepted { user_id, busy } => {
4139                update.contacts.push(contact_for_user(user_id, busy, pool));
4140            }
4141            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4142            db::Contact::Incoming { user_id } => {
4143                update
4144                    .incoming_requests
4145                    .push(proto::IncomingContactRequest {
4146                        requester_id: user_id.to_proto(),
4147                    })
4148            }
4149        }
4150    }
4151
4152    update
4153}
4154
4155fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4156    proto::Contact {
4157        user_id: user_id.to_proto(),
4158        online: pool.is_user_online(user_id),
4159        busy,
4160    }
4161}
4162
4163fn room_updated(room: &proto::Room, peer: &Peer) {
4164    broadcast(
4165        None,
4166        room.participants
4167            .iter()
4168            .filter_map(|participant| Some(participant.peer_id?.into())),
4169        |peer_id| {
4170            peer.send(
4171                peer_id,
4172                proto::RoomUpdated {
4173                    room: Some(room.clone()),
4174                },
4175            )
4176        },
4177    );
4178}
4179
4180fn channel_updated(
4181    channel: &db::channel::Model,
4182    room: &proto::Room,
4183    peer: &Peer,
4184    pool: &ConnectionPool,
4185) {
4186    let participants = room
4187        .participants
4188        .iter()
4189        .map(|p| p.user_id)
4190        .collect::<Vec<_>>();
4191
4192    broadcast(
4193        None,
4194        pool.channel_connection_ids(channel.root_id())
4195            .filter_map(|(channel_id, role)| {
4196                role.can_see_channel(channel.visibility)
4197                    .then_some(channel_id)
4198            }),
4199        |peer_id| {
4200            peer.send(
4201                peer_id,
4202                proto::UpdateChannels {
4203                    channel_participants: vec![proto::ChannelParticipants {
4204                        channel_id: channel.id.to_proto(),
4205                        participant_user_ids: participants.clone(),
4206                    }],
4207                    ..Default::default()
4208                },
4209            )
4210        },
4211    );
4212}
4213
4214async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4215    let db = session.db().await;
4216
4217    let contacts = db.get_contacts(user_id).await?;
4218    let busy = db.is_user_busy(user_id).await?;
4219
4220    let pool = session.connection_pool().await;
4221    let updated_contact = contact_for_user(user_id, busy, &pool);
4222    for contact in contacts {
4223        if let db::Contact::Accepted {
4224            user_id: contact_user_id,
4225            ..
4226        } = contact
4227        {
4228            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4229                session
4230                    .peer
4231                    .send(
4232                        contact_conn_id,
4233                        proto::UpdateContacts {
4234                            contacts: vec![updated_contact.clone()],
4235                            remove_contacts: Default::default(),
4236                            incoming_requests: Default::default(),
4237                            remove_incoming_requests: Default::default(),
4238                            outgoing_requests: Default::default(),
4239                            remove_outgoing_requests: Default::default(),
4240                        },
4241                    )
4242                    .trace_err();
4243            }
4244        }
4245    }
4246    Ok(())
4247}
4248
4249async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4250    let mut contacts_to_update = HashSet::default();
4251
4252    let room_id;
4253    let canceled_calls_to_user_ids;
4254    let livekit_room;
4255    let delete_livekit_room;
4256    let room;
4257    let channel;
4258
4259    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4260        contacts_to_update.insert(session.user_id());
4261
4262        for project in left_room.left_projects.values() {
4263            project_left(project, session);
4264        }
4265
4266        room_id = RoomId::from_proto(left_room.room.id);
4267        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4268        livekit_room = mem::take(&mut left_room.room.livekit_room);
4269        delete_livekit_room = left_room.deleted;
4270        room = mem::take(&mut left_room.room);
4271        channel = mem::take(&mut left_room.channel);
4272
4273        room_updated(&room, &session.peer);
4274    } else {
4275        return Ok(());
4276    }
4277
4278    if let Some(channel) = channel {
4279        channel_updated(
4280            &channel,
4281            &room,
4282            &session.peer,
4283            &*session.connection_pool().await,
4284        );
4285    }
4286
4287    {
4288        let pool = session.connection_pool().await;
4289        for canceled_user_id in canceled_calls_to_user_ids {
4290            for connection_id in pool.user_connection_ids(canceled_user_id) {
4291                session
4292                    .peer
4293                    .send(
4294                        connection_id,
4295                        proto::CallCanceled {
4296                            room_id: room_id.to_proto(),
4297                        },
4298                    )
4299                    .trace_err();
4300            }
4301            contacts_to_update.insert(canceled_user_id);
4302        }
4303    }
4304
4305    for contact_user_id in contacts_to_update {
4306        update_user_contacts(contact_user_id, session).await?;
4307    }
4308
4309    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4310        live_kit
4311            .remove_participant(livekit_room.clone(), session.user_id().to_string())
4312            .await
4313            .trace_err();
4314
4315        if delete_livekit_room {
4316            live_kit.delete_room(livekit_room).await.trace_err();
4317        }
4318    }
4319
4320    Ok(())
4321}
4322
4323async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4324    let left_channel_buffers = session
4325        .db()
4326        .await
4327        .leave_channel_buffers(session.connection_id)
4328        .await?;
4329
4330    for left_buffer in left_channel_buffers {
4331        channel_buffer_updated(
4332            session.connection_id,
4333            left_buffer.connections,
4334            &proto::UpdateChannelBufferCollaborators {
4335                channel_id: left_buffer.channel_id.to_proto(),
4336                collaborators: left_buffer.collaborators,
4337            },
4338            &session.peer,
4339        );
4340    }
4341
4342    Ok(())
4343}
4344
4345fn project_left(project: &db::LeftProject, session: &Session) {
4346    for connection_id in &project.connection_ids {
4347        if project.should_unshare {
4348            session
4349                .peer
4350                .send(
4351                    *connection_id,
4352                    proto::UnshareProject {
4353                        project_id: project.id.to_proto(),
4354                    },
4355                )
4356                .trace_err();
4357        } else {
4358            session
4359                .peer
4360                .send(
4361                    *connection_id,
4362                    proto::RemoveProjectCollaborator {
4363                        project_id: project.id.to_proto(),
4364                        peer_id: Some(session.connection_id.into()),
4365                    },
4366                )
4367                .trace_err();
4368        }
4369    }
4370}
4371
4372pub trait ResultExt {
4373    type Ok;
4374
4375    fn trace_err(self) -> Option<Self::Ok>;
4376}
4377
4378impl<T, E> ResultExt for Result<T, E>
4379where
4380    E: std::fmt::Debug,
4381{
4382    type Ok = T;
4383
4384    #[track_caller]
4385    fn trace_err(self) -> Option<T> {
4386        match self {
4387            Ok(value) => Some(value),
4388            Err(error) => {
4389                tracing::error!("{:?}", error);
4390                None
4391            }
4392        }
4393    }
4394}