rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::{
   5    AppState, Error, Result, auth,
   6    db::{
   7        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
   8        InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
   9        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
  10    },
  11    executor::Executor,
  12};
  13use anyhow::{Context as _, anyhow, bail};
  14use async_tungstenite::tungstenite::{
  15    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  16};
  17use axum::headers::UserAgent;
  18use axum::{
  19    Extension, Router, TypedHeader,
  20    body::Body,
  21    extract::{
  22        ConnectInfo, WebSocketUpgrade,
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30};
  31use collections::{HashMap, HashSet};
  32pub use connection_pool::{ConnectionPool, ZedVersion};
  33use core::fmt::{self, Debug, Formatter};
  34use futures::TryFutureExt as _;
  35use reqwest_client::ReqwestClient;
  36use rpc::proto::{MultiLspQuery, split_repository_update};
  37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  38use tracing::Span;
  39
  40use futures::{
  41    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  42    stream::FuturesUnordered,
  43};
  44use prometheus::{IntGauge, register_int_gauge};
  45use rpc::{
  46    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  47    proto::{
  48        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  49        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  50    },
  51};
  52use semantic_version::SemanticVersion;
  53use serde::{Serialize, Serializer};
  54use std::{
  55    any::TypeId,
  56    future::Future,
  57    marker::PhantomData,
  58    mem,
  59    net::SocketAddr,
  60    ops::{Deref, DerefMut},
  61    rc::Rc,
  62    sync::{
  63        Arc, OnceLock,
  64        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  65    },
  66    time::{Duration, Instant},
  67};
  68use tokio::sync::{Semaphore, watch};
  69use tower::ServiceBuilder;
  70use tracing::{
  71    Instrument,
  72    field::{self},
  73    info_span, instrument,
  74};
  75
// NOTE(review): not used in this chunk; presumably the grace period a disconnected client has
// to rejoin before its state is released — confirm at the use site.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean
// up the resources they left behind, so we wait a little longer than that.
/// Delay before the background task spawned by `Server::start` begins cleaning up resources
/// left behind by stale servers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): not used in this chunk; presumably the page size for notification queries.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Maximum number of simultaneously admitted connections, enforced by
/// `ConnectionGuard::try_acquire`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Number of outstanding `ConnectionGuard`s. Failed acquisitions bump it briefly before
/// rolling back.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Span field names. The first three are recorded for every handled message (in `add_handler`);
// `HOST_WAITING_MS` is recorded when a request is forwarded to a project host
// (`MessageContext::forward_request`).
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased message handler stored in `Server::handlers`. It receives the boxed envelope,
/// the sender's session and the message's span, and returns the future that handles it.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
  93
  94pub struct ConnectionGuard;
  95
  96impl ConnectionGuard {
  97    pub fn try_acquire() -> Result<Self, ()> {
  98        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
  99        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 100            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 101            tracing::error!(
 102                "too many concurrent connections: {}",
 103                current_connections + 1
 104            );
 105            return Err(());
 106        }
 107        Ok(ConnectionGuard)
 108    }
 109}
 110
 111impl Drop for ConnectionGuard {
 112    fn drop(&mut self) {
 113        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 114    }
 115}
 116
 117struct Response<R> {
 118    peer: Arc<Peer>,
 119    receipt: Receipt<R>,
 120    responded: Arc<AtomicBool>,
 121}
 122
 123impl<R: RequestMessage> Response<R> {
 124    fn send(self, payload: R::Response) -> Result<()> {
 125        self.responded.store(true, SeqCst);
 126        self.peer.respond(self.receipt, payload)?;
 127        Ok(())
 128    }
 129}
 130
 131#[derive(Clone, Debug)]
 132pub enum Principal {
 133    User(User),
 134    Impersonated { user: User, admin: User },
 135}
 136
 137impl Principal {
 138    fn update_span(&self, span: &tracing::Span) {
 139        match &self {
 140            Principal::User(user) => {
 141                span.record("user_id", user.id.0);
 142                span.record("login", &user.github_login);
 143            }
 144            Principal::Impersonated { user, admin } => {
 145                span.record("user_id", user.id.0);
 146                span.record("login", &user.github_login);
 147                span.record("impersonator", &admin.github_login);
 148            }
 149        }
 150    }
 151}
 152
 153#[derive(Clone)]
 154struct MessageContext {
 155    session: Session,
 156    span: tracing::Span,
 157}
 158
 159impl Deref for MessageContext {
 160    type Target = Session;
 161
 162    fn deref(&self) -> &Self::Target {
 163        &self.session
 164    }
 165}
 166
 167impl MessageContext {
 168    pub fn forward_request<T: RequestMessage>(
 169        &self,
 170        receiver_id: ConnectionId,
 171        request: T,
 172    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 173        let request_start_time = Instant::now();
 174        let span = self.span.clone();
 175        tracing::info!("start forwarding request");
 176        self.peer
 177            .forward_request(self.connection_id, receiver_id, request)
 178            .inspect(move |_| {
 179                span.record(
 180                    HOST_WAITING_MS,
 181                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 182                );
 183            })
 184            .inspect_err(|_| tracing::error!("error forwarding request"))
 185            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 186    }
 187}
 188
 189#[derive(Clone)]
 190struct Session {
 191    principal: Principal,
 192    connection_id: ConnectionId,
 193    db: Arc<tokio::sync::Mutex<DbHandle>>,
 194    peer: Arc<Peer>,
 195    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 196    app_state: Arc<AppState>,
 197    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 198    /// The GeoIP country code for the user.
 199    #[allow(unused)]
 200    geoip_country_code: Option<String>,
 201    #[allow(unused)]
 202    system_id: Option<String>,
 203    _executor: Executor,
 204}
 205
 206impl Session {
 207    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 208        #[cfg(test)]
 209        tokio::task::yield_now().await;
 210        let guard = self.db.lock().await;
 211        #[cfg(test)]
 212        tokio::task::yield_now().await;
 213        guard
 214    }
 215
 216    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 217        #[cfg(test)]
 218        tokio::task::yield_now().await;
 219        let guard = self.connection_pool.lock();
 220        ConnectionPoolGuard {
 221            guard,
 222            _not_send: PhantomData,
 223        }
 224    }
 225
 226    fn is_staff(&self) -> bool {
 227        match &self.principal {
 228            Principal::User(user) => user.admin,
 229            Principal::Impersonated { .. } => true,
 230        }
 231    }
 232
 233    fn user_id(&self) -> UserId {
 234        match &self.principal {
 235            Principal::User(user) => user.id,
 236            Principal::Impersonated { user, .. } => user.id,
 237        }
 238    }
 239
 240    pub fn email(&self) -> Option<String> {
 241        match &self.principal {
 242            Principal::User(user) => user.email_address.clone(),
 243            Principal::Impersonated { user, .. } => user.email_address.clone(),
 244        }
 245    }
 246}
 247
 248impl Debug for Session {
 249    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 250        let mut result = f.debug_struct("Session");
 251        match &self.principal {
 252            Principal::User(user) => {
 253                result.field("user", &user.github_login);
 254            }
 255            Principal::Impersonated { user, admin } => {
 256                result.field("user", &user.github_login);
 257                result.field("impersonator", &admin.github_login);
 258            }
 259        }
 260        result.field("connection_id", &self.connection_id).finish()
 261    }
 262}
 263
 264struct DbHandle(Arc<Database>);
 265
 266impl Deref for DbHandle {
 267    type Target = Database;
 268
 269    fn deref(&self) -> &Self::Target {
 270        self.0.as_ref()
 271    }
 272}
 273
/// The collaboration RPC server: owns the peer, the shared connection pool and the table of
/// message handlers.
pub struct Server {
    /// This server's id. It sits behind a mutex because the test-only `reset` overwrites it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept (filled by `add_handler`).
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag: `teardown` publishes `true`, and the test-only `reset` publishes `false`.
    teardown: watch::Sender<bool>,
}
 282
/// Lock guard over the shared [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. It therefore cannot be held
/// across an `.await` inside a future that must be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable view of the server: the peer and the (locked) connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard itself is not `Serialize`; `serialize_deref` serializes its deref target.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 294
 295pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 296where
 297    S: Serializer,
 298    T: Deref<Target = U>,
 299    U: Serialize,
 300{
 301    Serialize::serialize(value.deref(), serializer)
 302}
 303
 304impl Server {
 305    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 306        let mut server = Self {
 307            id: parking_lot::Mutex::new(id),
 308            peer: Peer::new(id.0 as u32),
 309            app_state,
 310            connection_pool: Default::default(),
 311            handlers: Default::default(),
 312            teardown: watch::channel(false).0,
 313        };
 314
 315        server
 316            .add_request_handler(ping)
 317            .add_request_handler(create_room)
 318            .add_request_handler(join_room)
 319            .add_request_handler(rejoin_room)
 320            .add_request_handler(leave_room)
 321            .add_request_handler(set_room_participant_role)
 322            .add_request_handler(call)
 323            .add_request_handler(cancel_call)
 324            .add_message_handler(decline_call)
 325            .add_request_handler(update_participant_location)
 326            .add_request_handler(share_project)
 327            .add_message_handler(unshare_project)
 328            .add_request_handler(join_project)
 329            .add_message_handler(leave_project)
 330            .add_request_handler(update_project)
 331            .add_request_handler(update_worktree)
 332            .add_request_handler(update_repository)
 333            .add_request_handler(remove_repository)
 334            .add_message_handler(start_language_server)
 335            .add_message_handler(update_language_server)
 336            .add_message_handler(update_diagnostic_summary)
 337            .add_message_handler(update_worktree_settings)
 338            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 339            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 340            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 341            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 342            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 343            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 344            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 345            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 346            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 347            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 348            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 349            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 350            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 351            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 352            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
 353            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 354            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 355            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 356            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 357            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 358            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 359            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 360            .add_request_handler(
 361                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 362            )
 363            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 364            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 365            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 366            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 367            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
 368            .add_request_handler(
 369                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 370            )
 371            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 372            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 373            .add_request_handler(
 374                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 375            )
 376            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 377            .add_request_handler(
 378                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 379            )
 380            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 381            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 382            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 383            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 384            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 385            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 386            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 387            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 388            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 389            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 390            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 391            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 392            .add_request_handler(
 393                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 394            )
 395            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 396            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 397            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 398            .add_request_handler(multi_lsp_query)
 399            .add_request_handler(lsp_query)
 400            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
 401            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 402            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 403            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 404            .add_message_handler(create_buffer_for_peer)
 405            .add_request_handler(update_buffer)
 406            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 407            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 408            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 409            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 410            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 411            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 412            .add_message_handler(
 413                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 414            )
 415            .add_request_handler(get_users)
 416            .add_request_handler(fuzzy_search_users)
 417            .add_request_handler(request_contact)
 418            .add_request_handler(remove_contact)
 419            .add_request_handler(respond_to_contact_request)
 420            .add_message_handler(subscribe_to_channels)
 421            .add_request_handler(create_channel)
 422            .add_request_handler(delete_channel)
 423            .add_request_handler(invite_channel_member)
 424            .add_request_handler(remove_channel_member)
 425            .add_request_handler(set_channel_member_role)
 426            .add_request_handler(set_channel_visibility)
 427            .add_request_handler(rename_channel)
 428            .add_request_handler(join_channel_buffer)
 429            .add_request_handler(leave_channel_buffer)
 430            .add_message_handler(update_channel_buffer)
 431            .add_request_handler(rejoin_channel_buffers)
 432            .add_request_handler(get_channel_members)
 433            .add_request_handler(respond_to_channel_invite)
 434            .add_request_handler(join_channel)
 435            .add_request_handler(join_channel_chat)
 436            .add_message_handler(leave_channel_chat)
 437            .add_request_handler(send_channel_message)
 438            .add_request_handler(remove_channel_message)
 439            .add_request_handler(update_channel_message)
 440            .add_request_handler(get_channel_messages)
 441            .add_request_handler(get_channel_messages_by_id)
 442            .add_request_handler(get_notifications)
 443            .add_request_handler(mark_notification_as_read)
 444            .add_request_handler(move_channel)
 445            .add_request_handler(reorder_channel)
 446            .add_request_handler(follow)
 447            .add_message_handler(unfollow)
 448            .add_message_handler(update_followers)
 449            .add_message_handler(acknowledge_channel_message)
 450            .add_message_handler(acknowledge_buffer_version)
 451            .add_request_handler(get_supermaven_api_key)
 452            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 453            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 454            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 455            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 456            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 457            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 458            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 459            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 460            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 461            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 462            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 463            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 464            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 465            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 466            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 467            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 468            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 469            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 470            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 471            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 472            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 473            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 474            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 475            .add_message_handler(update_context)
 476            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
 477            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>);
 478
 479        Arc::new(server)
 480    }
 481
    /// Schedules a one-off background cleanup of state left behind by stale servers.
    ///
    /// Returns `Ok(())` right after spawning a detached task. That task waits for
    /// `CLEANUP_TIMEOUT` and then runs the steps below. Every database error along the way is
    /// passed through `trace_err` (logged) instead of aborting the task.
    /// 1. Deletes stale channel chat participants.
    /// 2. For each stale channel buffer: clears stale collaborators and sends the refreshed
    ///    collaborator list to the connections still attached to it.
    /// 3. For each stale room: clears stale participants and broadcasts the updated room (and
    ///    channel, if any). It then sends `CallCanceled` to users whose calls were canceled, and
    ///    pushes refreshed contact entries for every affected user to their accepted contacts.
    ///    Finally, when a LiveKit client is configured and the room has no participants left,
    ///    it deletes the LiveKit room.
    /// 4. Repeats the chat-participant cleanup, clears old worktree entries and calls
    ///    `delete_stale_servers`.
    pub async fn start(&self) -> Result<()> {
        // Take owned copies of everything the detached task needs.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                // NOTE(review): this exact call is repeated after the room cleanup below; one of
                // the two looks redundant — confirm whether both are intended.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and push the refreshed
                    // collaborator list to every connection still in that buffer.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These defaults turn every step after the DB call into a no-op if
                        // clearing the room fails.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Stale participants and users whose calls were canceled both need
                            // their contact entries refreshed.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            // Move these out; `refreshed_room` is not used after this block.
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Send each affected user's refreshed contact entry (including busy
                        // state) to every connection of each of their accepted contacts. The pool
                        // lock is only taken after both awaits have completed.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room once nobody is left in it.
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                // Second run of the chat-participant cleanup that also opens this task (see the
                // NOTE(review) above).
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 652
 653    pub fn teardown(&self) {
 654        self.peer.teardown();
 655        self.connection_pool.lock().reset();
 656        let _ = self.teardown.send(true);
 657    }
 658
 659    #[cfg(test)]
 660    pub fn reset(&self, id: ServerId) {
 661        self.teardown();
 662        *self.id.lock() = id;
 663        self.peer.reset(id.0 as u32);
 664        let _ = self.teardown.send(false);
 665    }
 666
    /// Test-only: returns the server id currently stored on this server
    /// (it is replaced by `reset`).
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
 671
    /// Registers `handler` as the handler for messages of type `M`.
    ///
    /// The stored closure downcasts the type-erased envelope back to
    /// `TypedEnvelope<M>`, invokes `handler` with a [`MessageContext`], and
    /// returns a boxed future. When that future completes it records total,
    /// processing and queue durations (milliseconds) on the message span and
    /// logs the outcome. Handler errors are logged, not propagated.
    ///
    /// # Panics
    /// Panics if a handler for `M` has already been registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // Handlers are looked up by payload `TypeId`, and this one is
                // stored under `TypeId::of::<M>()`, so the downcast is expected
                // to succeed.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                // Processing time is measured from the moment the handler is
                // invoked; the remainder of the total since `received_at` is
                // recorded as queue time.
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    // Errors end here: they are logged but not returned.
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
 715
 716    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 717    where
 718        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 719        Fut: 'static + Send + Future<Output = Result<()>>,
 720        M: EnvelopedMessage,
 721    {
 722        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 723        self
 724    }
 725
 726    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 727    where
 728        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 729        Fut: Send + Future<Output = Result<()>>,
 730        M: RequestMessage,
 731    {
 732        let handler = Arc::new(handler);
 733        self.add_handler(move |envelope, session| {
 734            let receipt = envelope.receipt();
 735            let handler = handler.clone();
 736            async move {
 737                let peer = session.peer.clone();
 738                let responded = Arc::new(AtomicBool::default());
 739                let response = Response {
 740                    peer: peer.clone(),
 741                    responded: responded.clone(),
 742                    receipt,
 743                };
 744                match (handler)(envelope.payload, response, session).await {
 745                    Ok(()) => {
 746                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 747                            Ok(())
 748                        } else {
 749                            Err(anyhow!("handler did not send a response"))?
 750                        }
 751                    }
 752                    Err(error) => {
 753                        let proto_err = match &error {
 754                            Error::Internal(err) => err.to_proto(),
 755                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 756                        };
 757                        peer.respond_with_error(receipt, proto_err)?;
 758                        Err(error)
 759                    }
 760                }
 761            }
 762        })
 763    }
 764
    /// Drives a single client connection from handshake to sign-out.
    ///
    /// Returns a future, instrumented with a "handle connection" span, that:
    /// 1. refuses to start if the teardown flag is already set;
    /// 2. registers `connection` with the peer and builds a [`Session`] for
    ///    `principal`;
    /// 3. performs the initial client update (hello, contacts, channels, calls);
    /// 4. dispatches incoming messages to the registered handlers until the
    ///    connection closes, its I/O fails, or teardown is signaled. Background
    ///    messages are spawned on `executor`; foreground ones are polled by this
    ///    loop;
    /// 5. on a normal exit, calls `connection_lost` to clean up.
    ///
    /// `send_connection_id`, when present, is sent the new connection id after
    /// `Hello` has gone out. `connection_guard` is released once the initial
    /// client update succeeds (or when the future returns early). All errors
    /// are logged rather than returned.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Do not accept connections while the server is tearing down.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                // NOTE(review): this early return (like the HTTP-client one above)
                // skips `connection_lost`, so the connection registered via
                // `peer.add_connection` -- and, if the failure happened after
                // `pool.add_connection`, the pool entry -- is not removed here.
                // Verify whether cleanup happens elsewhere.
                return;
            }
            // Handshake done: release the slot acquired before the WebSocket upgrade.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            //
            // A semaphore permit is acquired before each message is read and released only
            // when its handler finishes, so at most MAX_CONCURRENT_HANDLERS handlers
            // (foreground and background together) are in flight per connection.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // Branch priority (biased): teardown, connection I/O, completing a
                // foreground handler, then reading the next message. Teardown
                // returns immediately without running `connection_lost`.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                // todo(lsp) remove after Zed Stable hits v0.204.x
                                multi_lsp_query_request=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // The permit moves into the handler future so the
                                // concurrency slot is held until the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
 958
 959    async fn send_initial_client_update(
 960        &self,
 961        connection_id: ConnectionId,
 962        zed_version: ZedVersion,
 963        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 964        session: &Session,
 965    ) -> Result<()> {
 966        self.peer.send(
 967            connection_id,
 968            proto::Hello {
 969                peer_id: Some(connection_id.into()),
 970            },
 971        )?;
 972        tracing::info!("sent hello message");
 973        if let Some(send_connection_id) = send_connection_id.take() {
 974            let _ = send_connection_id.send(connection_id);
 975        }
 976
 977        match &session.principal {
 978            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 979                if !user.connected_once {
 980                    self.peer.send(connection_id, proto::ShowContacts {})?;
 981                    self.app_state
 982                        .db
 983                        .set_user_connected_once(user.id, true)
 984                        .await?;
 985                }
 986
 987                let contacts = self.app_state.db.get_contacts(user.id).await?;
 988
 989                {
 990                    let mut pool = self.connection_pool.lock();
 991                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 992                    self.peer.send(
 993                        connection_id,
 994                        build_initial_contacts_update(contacts, &pool),
 995                    )?;
 996                }
 997
 998                if should_auto_subscribe_to_channels(zed_version) {
 999                    subscribe_user_to_channels(user.id, session).await?;
1000                }
1001
1002                if let Some(incoming_call) =
1003                    self.app_state.db.incoming_call_for_user(user.id).await?
1004                {
1005                    self.peer.send(connection_id, incoming_call)?;
1006                }
1007
1008                update_user_contacts(user.id, session).await?;
1009            }
1010        }
1011
1012        Ok(())
1013    }
1014
1015    pub async fn invite_code_redeemed(
1016        self: &Arc<Self>,
1017        inviter_id: UserId,
1018        invitee_id: UserId,
1019    ) -> Result<()> {
1020        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await?
1021            && let Some(code) = &user.invite_code
1022        {
1023            let pool = self.connection_pool.lock();
1024            let invitee_contact = contact_for_user(invitee_id, false, &pool);
1025            for connection_id in pool.user_connection_ids(inviter_id) {
1026                self.peer.send(
1027                    connection_id,
1028                    proto::UpdateContacts {
1029                        contacts: vec![invitee_contact.clone()],
1030                        ..Default::default()
1031                    },
1032                )?;
1033                self.peer.send(
1034                    connection_id,
1035                    proto::UpdateInviteInfo {
1036                        url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1037                        count: user.invite_count as u32,
1038                    },
1039                )?;
1040            }
1041        }
1042        Ok(())
1043    }
1044
1045    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1046        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await?
1047            && let Some(invite_code) = &user.invite_code
1048        {
1049            let pool = self.connection_pool.lock();
1050            for connection_id in pool.user_connection_ids(user_id) {
1051                self.peer.send(
1052                    connection_id,
1053                    proto::UpdateInviteInfo {
1054                        url: format!(
1055                            "{}{}",
1056                            self.app_state.config.invite_link_prefix, invite_code
1057                        ),
1058                        count: user.invite_count as u32,
1059                    },
1060                )?;
1061            }
1062        }
1063        Ok(())
1064    }
1065
1066    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1067        ServerSnapshot {
1068            connection_pool: ConnectionPoolGuard {
1069                guard: self.connection_pool.lock(),
1070                _not_send: PhantomData,
1071            },
1072            peer: &self.peer,
1073        }
1074    }
1075}
1076
1077impl Deref for ConnectionPoolGuard<'_> {
1078    type Target = ConnectionPool;
1079
1080    fn deref(&self) -> &Self::Target {
1081        &self.guard
1082    }
1083}
1084
1085impl DerefMut for ConnectionPoolGuard<'_> {
1086    fn deref_mut(&mut self) -> &mut Self::Target {
1087        &mut self.guard
1088    }
1089}
1090
/// In test builds, dropping the guard runs the pool's invariant checks while
/// the lock is still held (the `guard` field is dropped after `drop` returns).
/// In non-test builds this does nothing extra.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1097
1098fn broadcast<F>(
1099    sender_id: Option<ConnectionId>,
1100    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1101    mut f: F,
1102) where
1103    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1104{
1105    for receiver_id in receiver_ids {
1106        if Some(receiver_id) != sender_id
1107            && let Err(error) = f(receiver_id)
1108        {
1109            tracing::error!("failed to send to {:?} {}", receiver_id, error);
1110        }
1111    }
1112}
1113
1114pub struct ProtocolVersion(u32);
1115
1116impl Header for ProtocolVersion {
1117    fn name() -> &'static HeaderName {
1118        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1119        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1120    }
1121
1122    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1123    where
1124        Self: Sized,
1125        I: Iterator<Item = &'i axum::http::HeaderValue>,
1126    {
1127        let version = values
1128            .next()
1129            .ok_or_else(axum::headers::Error::invalid)?
1130            .to_str()
1131            .map_err(|_| axum::headers::Error::invalid())?
1132            .parse()
1133            .map_err(|_| axum::headers::Error::invalid())?;
1134        Ok(Self(version))
1135    }
1136
1137    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1138        values.extend([self.0.to_string().parse().unwrap()]);
1139    }
1140}
1141
1142pub struct AppVersionHeader(SemanticVersion);
1143impl Header for AppVersionHeader {
1144    fn name() -> &'static HeaderName {
1145        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1146        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1147    }
1148
1149    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1150    where
1151        Self: Sized,
1152        I: Iterator<Item = &'i axum::http::HeaderValue>,
1153    {
1154        let version = values
1155            .next()
1156            .ok_or_else(axum::headers::Error::invalid)?
1157            .to_str()
1158            .map_err(|_| axum::headers::Error::invalid())?
1159            .parse()
1160            .map_err(|_| axum::headers::Error::invalid())?;
1161        Ok(Self(version))
1162    }
1163
1164    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1165        values.extend([self.0.to_string().parse().unwrap()]);
1166    }
1167}
1168
1169#[derive(Debug)]
1170pub struct ReleaseChannelHeader(String);
1171
1172impl Header for ReleaseChannelHeader {
1173    fn name() -> &'static HeaderName {
1174        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1175        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1176    }
1177
1178    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1179    where
1180        Self: Sized,
1181        I: Iterator<Item = &'i axum::http::HeaderValue>,
1182    {
1183        Ok(Self(
1184            values
1185                .next()
1186                .ok_or_else(axum::headers::Error::invalid)?
1187                .to_str()
1188                .map_err(|_| axum::headers::Error::invalid())?
1189                .to_owned(),
1190        ))
1191    }
1192
1193    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1194        values.extend([self.0.parse().unwrap()]);
1195    }
1196}
1197
1198pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1199    Router::new()
1200        .route("/rpc", get(handle_websocket_request))
1201        .layer(
1202            ServiceBuilder::new()
1203                .layer(Extension(server.app_state.clone()))
1204                .layer(middleware::from_fn(auth::validate_header)),
1205        )
1206        .route("/metrics", get(handle_metrics))
1207        .layer(Extension(server))
1208}
1209
1210pub async fn handle_websocket_request(
1211    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1212    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1213    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1214    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1215    Extension(server): Extension<Arc<Server>>,
1216    Extension(principal): Extension<Principal>,
1217    user_agent: Option<TypedHeader<UserAgent>>,
1218    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1219    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1220    ws: WebSocketUpgrade,
1221) -> axum::response::Response {
1222    if protocol_version != rpc::PROTOCOL_VERSION {
1223        return (
1224            StatusCode::UPGRADE_REQUIRED,
1225            "client must be upgraded".to_string(),
1226        )
1227            .into_response();
1228    }
1229
1230    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1231        return (
1232            StatusCode::UPGRADE_REQUIRED,
1233            "no version header found".to_string(),
1234        )
1235            .into_response();
1236    };
1237
1238    let release_channel = release_channel_header.map(|header| header.0.0);
1239
1240    if !version.can_collaborate() {
1241        return (
1242            StatusCode::UPGRADE_REQUIRED,
1243            "client must be upgraded".to_string(),
1244        )
1245            .into_response();
1246    }
1247
1248    let socket_address = socket_address.to_string();
1249
1250    // Acquire connection guard before WebSocket upgrade
1251    let connection_guard = match ConnectionGuard::try_acquire() {
1252        Ok(guard) => guard,
1253        Err(()) => {
1254            return (
1255                StatusCode::SERVICE_UNAVAILABLE,
1256                "Too many concurrent connections",
1257            )
1258                .into_response();
1259        }
1260    };
1261
1262    ws.on_upgrade(move |socket| {
1263        let socket = socket
1264            .map_ok(to_tungstenite_message)
1265            .err_into()
1266            .with(|message| async move { to_axum_message(message) });
1267        let connection = Connection::new(Box::pin(socket));
1268        async move {
1269            server
1270                .handle_connection(
1271                    connection,
1272                    socket_address,
1273                    principal,
1274                    version,
1275                    release_channel,
1276                    user_agent.map(|header| header.to_string()),
1277                    country_code_header.map(|header| header.to_string()),
1278                    system_id_header.map(|header| header.to_string()),
1279                    None,
1280                    Executor::Production,
1281                    Some(connection_guard),
1282                )
1283                .await;
1284        }
1285    })
1286}
1287
1288pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1289    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1290    let connections_metric = CONNECTIONS_METRIC
1291        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1292
1293    let connections = server
1294        .connection_pool
1295        .lock()
1296        .connections()
1297        .filter(|connection| !connection.admin)
1298        .count();
1299    connections_metric.set(connections as _);
1300
1301    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1302    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1303        register_int_gauge!(
1304            "shared_projects",
1305            "number of open projects with one or more guests"
1306        )
1307        .unwrap()
1308    });
1309
1310    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1311    shared_projects_metric.set(shared_projects as _);
1312
1313    let encoder = prometheus::TextEncoder::new();
1314    let metric_families = prometheus::gather();
1315    let encoded_metrics = encoder
1316        .encode_to_string(&metric_families)
1317        .map_err(|err| anyhow!("{err}"))?;
1318    Ok(encoded_metrics)
1319}
1320
1321#[instrument(err, skip(executor))]
1322async fn connection_lost(
1323    session: Session,
1324    mut teardown: watch::Receiver<bool>,
1325    executor: Executor,
1326) -> Result<()> {
1327    session.peer.disconnect(session.connection_id);
1328    session
1329        .connection_pool()
1330        .await
1331        .remove_connection(session.connection_id)?;
1332
1333    session
1334        .db()
1335        .await
1336        .connection_lost(session.connection_id)
1337        .await
1338        .trace_err();
1339
1340    futures::select_biased! {
1341        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1342
1343            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1344            leave_room_for_session(&session, session.connection_id).await.trace_err();
1345            leave_channel_buffers_for_session(&session)
1346                .await
1347                .trace_err();
1348
1349            if !session
1350                .connection_pool()
1351                .await
1352                .is_user_online(session.user_id())
1353            {
1354                let db = session.db().await;
1355                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1356                    room_updated(&room, &session.peer);
1357                }
1358            }
1359
1360            update_user_contacts(session.user_id(), &session).await?;
1361        },
1362        _ = teardown.changed().fuse() => {}
1363    }
1364
1365    Ok(())
1366}
1367
1368/// Acknowledges a ping from a client, used to keep the connection alive.
1369async fn ping(
1370    _: proto::Ping,
1371    response: Response<proto::Ping>,
1372    _session: MessageContext,
1373) -> Result<()> {
1374    response.send(proto::Ack {})?;
1375    Ok(())
1376}
1377
1378/// Creates a new room for calling (outside of channels)
1379async fn create_room(
1380    _request: proto::CreateRoom,
1381    response: Response<proto::CreateRoom>,
1382    session: MessageContext,
1383) -> Result<()> {
1384    let livekit_room = nanoid::nanoid!(30);
1385
1386    let live_kit_connection_info = util::maybe!(async {
1387        let live_kit = session.app_state.livekit_client.as_ref();
1388        let live_kit = live_kit?;
1389        let user_id = session.user_id().to_string();
1390
1391        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1392
1393        Some(proto::LiveKitConnectionInfo {
1394            server_url: live_kit.url().into(),
1395            token,
1396            can_publish: true,
1397        })
1398    })
1399    .await;
1400
1401    let room = session
1402        .db()
1403        .await
1404        .create_room(session.user_id(), session.connection_id, &livekit_room)
1405        .await?;
1406
1407    response.send(proto::CreateRoomResponse {
1408        room: Some(room.clone()),
1409        live_kit_connection_info,
1410    })?;
1411
1412    update_user_contacts(session.user_id(), &session).await?;
1413    Ok(())
1414}
1415
1416/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1417async fn join_room(
1418    request: proto::JoinRoom,
1419    response: Response<proto::JoinRoom>,
1420    session: MessageContext,
1421) -> Result<()> {
1422    let room_id = RoomId::from_proto(request.id);
1423
1424    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1425
1426    if let Some(channel_id) = channel_id {
1427        return join_channel_internal(channel_id, Box::new(response), session).await;
1428    }
1429
1430    let joined_room = {
1431        let room = session
1432            .db()
1433            .await
1434            .join_room(room_id, session.user_id(), session.connection_id)
1435            .await?;
1436        room_updated(&room.room, &session.peer);
1437        room.into_inner()
1438    };
1439
1440    for connection_id in session
1441        .connection_pool()
1442        .await
1443        .user_connection_ids(session.user_id())
1444    {
1445        session
1446            .peer
1447            .send(
1448                connection_id,
1449                proto::CallCanceled {
1450                    room_id: room_id.to_proto(),
1451                },
1452            )
1453            .trace_err();
1454    }
1455
1456    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1457    {
1458        live_kit
1459            .room_token(
1460                &joined_room.room.livekit_room,
1461                &session.user_id().to_string(),
1462            )
1463            .trace_err()
1464            .map(|token| proto::LiveKitConnectionInfo {
1465                server_url: live_kit.url().into(),
1466                token,
1467                can_publish: true,
1468            })
1469    } else {
1470        None
1471    };
1472
1473    response.send(proto::JoinRoomResponse {
1474        room: Some(joined_room.room),
1475        channel_id: None,
1476        live_kit_connection_info,
1477    })?;
1478
1479    update_user_contacts(session.user_id(), &session).await?;
1480    Ok(())
1481}
1482
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database call reconciles this user's previous room participation with the new
/// connection. It returns the room plus two project lists, which are echoed back in the
/// response:
/// - reshared projects (presumably projects this connection hosts; confirm against
///   `Database::rejoin_room`): each collaborator is told that the peer id changed from
///   `old_connection_id` to this connection. The project's worktree list is then forwarded
///   to them as an `UpdateProject`.
/// - rejoined projects: handed to `notify_rejoined_projects`, which also streams their state
///   to this connection.
///
/// After the database result has been dropped, the room's channel (if any) is broadcast via
/// `channel_updated`, and the user's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    // Declared here so they outlive the database result, which is confined to the block below.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Reply with the room and the per-project reconnection details.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // For each reshared project, tell its collaborators about the new peer id, then forward
        // the project's current worktree list to them on behalf of this connection.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Take ownership of the result so `room` and `channel` can be moved out of it.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // Channel state is broadcast only after the database result above has been dropped.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1574
1575fn notify_rejoined_projects(
1576    rejoined_projects: &mut Vec<RejoinedProject>,
1577    session: &Session,
1578) -> Result<()> {
1579    for project in rejoined_projects.iter() {
1580        for collaborator in &project.collaborators {
1581            session
1582                .peer
1583                .send(
1584                    collaborator.connection_id,
1585                    proto::UpdateProjectCollaborator {
1586                        project_id: project.id.to_proto(),
1587                        old_peer_id: Some(project.old_connection_id.into()),
1588                        new_peer_id: Some(session.connection_id.into()),
1589                    },
1590                )
1591                .trace_err();
1592        }
1593    }
1594
1595    for project in rejoined_projects {
1596        for worktree in mem::take(&mut project.worktrees) {
1597            // Stream this worktree's entries.
1598            let message = proto::UpdateWorktree {
1599                project_id: project.id.to_proto(),
1600                worktree_id: worktree.id,
1601                abs_path: worktree.abs_path.clone(),
1602                root_name: worktree.root_name,
1603                updated_entries: worktree.updated_entries,
1604                removed_entries: worktree.removed_entries,
1605                scan_id: worktree.scan_id,
1606                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1607                updated_repositories: worktree.updated_repositories,
1608                removed_repositories: worktree.removed_repositories,
1609            };
1610            for update in proto::split_worktree_update(message) {
1611                session.peer.send(session.connection_id, update)?;
1612            }
1613
1614            // Stream this worktree's diagnostics.
1615            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1616            if let Some(summary) = worktree_diagnostics.next() {
1617                let message = proto::UpdateDiagnosticSummary {
1618                    project_id: project.id.to_proto(),
1619                    worktree_id: worktree.id,
1620                    summary: Some(summary),
1621                    more_summaries: worktree_diagnostics.collect(),
1622                };
1623                session.peer.send(session.connection_id, message)?;
1624            }
1625
1626            for settings_file in worktree.settings_files {
1627                session.peer.send(
1628                    session.connection_id,
1629                    proto::UpdateWorktreeSettings {
1630                        project_id: project.id.to_proto(),
1631                        worktree_id: worktree.id,
1632                        path: settings_file.path,
1633                        content: Some(settings_file.content),
1634                        kind: Some(settings_file.kind.to_proto().into()),
1635                    },
1636                )?;
1637            }
1638        }
1639
1640        for repository in mem::take(&mut project.updated_repositories) {
1641            for update in split_repository_update(repository) {
1642                session.peer.send(session.connection_id, update)?;
1643            }
1644        }
1645
1646        for id in mem::take(&mut project.removed_repositories) {
1647            session.peer.send(
1648                session.connection_id,
1649                proto::RemoveRepository {
1650                    project_id: project.id.to_proto(),
1651                    id,
1652                },
1653            )?;
1654        }
1655    }
1656
1657    Ok(())
1658}
1659
1660/// leave room disconnects from the room.
1661async fn leave_room(
1662    _: proto::LeaveRoom,
1663    response: Response<proto::LeaveRoom>,
1664    session: MessageContext,
1665) -> Result<()> {
1666    leave_room_for_session(&session, session.connection_id).await?;
1667    response.send(proto::Ack {})?;
1668    Ok(())
1669}
1670
1671/// Updates the permissions of someone else in the room.
1672async fn set_room_participant_role(
1673    request: proto::SetRoomParticipantRole,
1674    response: Response<proto::SetRoomParticipantRole>,
1675    session: MessageContext,
1676) -> Result<()> {
1677    let user_id = UserId::from_proto(request.user_id);
1678    let role = ChannelRole::from(request.role());
1679
1680    let (livekit_room, can_publish) = {
1681        let room = session
1682            .db()
1683            .await
1684            .set_room_participant_role(
1685                session.user_id(),
1686                RoomId::from_proto(request.room_id),
1687                user_id,
1688                role,
1689            )
1690            .await?;
1691
1692        let livekit_room = room.livekit_room.clone();
1693        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1694        room_updated(&room, &session.peer);
1695        (livekit_room, can_publish)
1696    };
1697
1698    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1699        live_kit
1700            .update_participant(
1701                livekit_room.clone(),
1702                request.user_id.to_string(),
1703                livekit_api::proto::ParticipantPermission {
1704                    can_subscribe: true,
1705                    can_publish,
1706                    can_publish_data: can_publish,
1707                    hidden: false,
1708                    recorder: false,
1709                },
1710            )
1711            .await
1712            .trace_err();
1713    }
1714
1715    response.send(proto::Ack {})?;
1716    Ok(())
1717}
1718
/// Call someone else into the current room.
///
/// The called user must be a contact of the caller. The outgoing call is recorded in the
/// database, the room is broadcast, and the called user's contacts are refreshed. The
/// incoming-call request is then sent to every connection of the called user, and this request
/// is acknowledged as soon as any of them replies successfully. If none does (including when
/// the user has no connections), the call is marked as failed in the database. The room and
/// contacts are then updated again and an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Only contacts of the caller may be called.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Record the outgoing call. The database result is scoped to this block so it is dropped
    // before the contacts update and the connection-pool access below.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        // Move the incoming-call message out of the borrowed result, leaving a default value.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the called user concurrently; replies arrive in completion order.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            // The first successful reply acknowledges the call.
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            // This connection failed to answer; trace the error and keep waiting on the rest.
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection answered successfully: mark the call as failed.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1787
1788/// Cancel an outgoing call.
1789async fn cancel_call(
1790    request: proto::CancelCall,
1791    response: Response<proto::CancelCall>,
1792    session: MessageContext,
1793) -> Result<()> {
1794    let called_user_id = UserId::from_proto(request.called_user_id);
1795    let room_id = RoomId::from_proto(request.room_id);
1796    {
1797        let room = session
1798            .db()
1799            .await
1800            .cancel_call(room_id, session.connection_id, called_user_id)
1801            .await?;
1802        room_updated(&room, &session.peer);
1803    }
1804
1805    for connection_id in session
1806        .connection_pool()
1807        .await
1808        .user_connection_ids(called_user_id)
1809    {
1810        session
1811            .peer
1812            .send(
1813                connection_id,
1814                proto::CallCanceled {
1815                    room_id: room_id.to_proto(),
1816                },
1817            )
1818            .trace_err();
1819    }
1820    response.send(proto::Ack {})?;
1821
1822    update_user_contacts(called_user_id, &session).await?;
1823    Ok(())
1824}
1825
1826/// Decline an incoming call.
1827async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1828    let room_id = RoomId::from_proto(message.room_id);
1829    {
1830        let room = session
1831            .db()
1832            .await
1833            .decline_call(Some(room_id), session.user_id())
1834            .await?
1835            .context("declining call")?;
1836        room_updated(&room, &session.peer);
1837    }
1838
1839    for connection_id in session
1840        .connection_pool()
1841        .await
1842        .user_connection_ids(session.user_id())
1843    {
1844        session
1845            .peer
1846            .send(
1847                connection_id,
1848                proto::CallCanceled {
1849                    room_id: room_id.to_proto(),
1850                },
1851            )
1852            .trace_err();
1853    }
1854    update_user_contacts(session.user_id(), &session).await?;
1855    Ok(())
1856}
1857
1858/// Updates other participants in the room with your current location.
1859async fn update_participant_location(
1860    request: proto::UpdateParticipantLocation,
1861    response: Response<proto::UpdateParticipantLocation>,
1862    session: MessageContext,
1863) -> Result<()> {
1864    let room_id = RoomId::from_proto(request.room_id);
1865    let location = request.location.context("invalid location")?;
1866
1867    let db = session.db().await;
1868    let room = db
1869        .update_room_participant_location(room_id, session.connection_id, location)
1870        .await?;
1871
1872    room_updated(&room, &session.peer);
1873    response.send(proto::Ack {})?;
1874    Ok(())
1875}
1876
1877/// Share a project into the room.
1878async fn share_project(
1879    request: proto::ShareProject,
1880    response: Response<proto::ShareProject>,
1881    session: MessageContext,
1882) -> Result<()> {
1883    let (project_id, room) = &*session
1884        .db()
1885        .await
1886        .share_project(
1887            RoomId::from_proto(request.room_id),
1888            session.connection_id,
1889            &request.worktrees,
1890            request.is_ssh_project,
1891        )
1892        .await?;
1893    response.send(proto::ShareProjectResponse {
1894        project_id: project_id.to_proto(),
1895    })?;
1896    room_updated(room, &session.peer);
1897
1898    Ok(())
1899}
1900
1901/// Unshare a project from the room.
1902async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1903    let project_id = ProjectId::from_proto(message.project_id);
1904    unshare_project_internal(project_id, session.connection_id, &session).await
1905}
1906
1907async fn unshare_project_internal(
1908    project_id: ProjectId,
1909    connection_id: ConnectionId,
1910    session: &Session,
1911) -> Result<()> {
1912    let delete = {
1913        let room_guard = session
1914            .db()
1915            .await
1916            .unshare_project(project_id, connection_id)
1917            .await?;
1918
1919        let (delete, room, guest_connection_ids) = &*room_guard;
1920
1921        let message = proto::UnshareProject {
1922            project_id: project_id.to_proto(),
1923        };
1924
1925        broadcast(
1926            Some(connection_id),
1927            guest_connection_ids.iter().copied(),
1928            |conn_id| session.peer.send(conn_id, message.clone()),
1929        );
1930        if let Some(room) = room {
1931            room_updated(room, &session.peer);
1932        }
1933
1934        *delete
1935    };
1936
1937    if delete {
1938        let db = session.db().await;
1939        db.delete_project(project_id).await?;
1940    }
1941
1942    Ok(())
1943}
1944
1945/// Join someone elses shared project.
1946async fn join_project(
1947    request: proto::JoinProject,
1948    response: Response<proto::JoinProject>,
1949    session: MessageContext,
1950) -> Result<()> {
1951    let project_id = ProjectId::from_proto(request.project_id);
1952
1953    tracing::info!(%project_id, "join project");
1954
1955    let db = session.db().await;
1956    let (project, replica_id) = &mut *db
1957        .join_project(
1958            project_id,
1959            session.connection_id,
1960            session.user_id(),
1961            request.committer_name.clone(),
1962            request.committer_email.clone(),
1963        )
1964        .await?;
1965    drop(db);
1966    tracing::info!(%project_id, "join remote project");
1967    let collaborators = project
1968        .collaborators
1969        .iter()
1970        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1971        .map(|collaborator| collaborator.to_proto())
1972        .collect::<Vec<_>>();
1973    let project_id = project.id;
1974    let guest_user_id = session.user_id();
1975
1976    let worktrees = project
1977        .worktrees
1978        .iter()
1979        .map(|(id, worktree)| proto::WorktreeMetadata {
1980            id: *id,
1981            root_name: worktree.root_name.clone(),
1982            visible: worktree.visible,
1983            abs_path: worktree.abs_path.clone(),
1984        })
1985        .collect::<Vec<_>>();
1986
1987    let add_project_collaborator = proto::AddProjectCollaborator {
1988        project_id: project_id.to_proto(),
1989        collaborator: Some(proto::Collaborator {
1990            peer_id: Some(session.connection_id.into()),
1991            replica_id: replica_id.0 as u32,
1992            user_id: guest_user_id.to_proto(),
1993            is_host: false,
1994            committer_name: request.committer_name.clone(),
1995            committer_email: request.committer_email.clone(),
1996        }),
1997    };
1998
1999    for collaborator in &collaborators {
2000        session
2001            .peer
2002            .send(
2003                collaborator.peer_id.unwrap().into(),
2004                add_project_collaborator.clone(),
2005            )
2006            .trace_err();
2007    }
2008
2009    // First, we send the metadata associated with each worktree.
2010    let (language_servers, language_server_capabilities) = project
2011        .language_servers
2012        .clone()
2013        .into_iter()
2014        .map(|server| (server.server, server.capabilities))
2015        .unzip();
2016    response.send(proto::JoinProjectResponse {
2017        project_id: project.id.0 as u64,
2018        worktrees,
2019        replica_id: replica_id.0 as u32,
2020        collaborators,
2021        language_servers,
2022        language_server_capabilities,
2023        role: project.role.into(),
2024    })?;
2025
2026    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2027        // Stream this worktree's entries.
2028        let message = proto::UpdateWorktree {
2029            project_id: project_id.to_proto(),
2030            worktree_id,
2031            abs_path: worktree.abs_path.clone(),
2032            root_name: worktree.root_name,
2033            updated_entries: worktree.entries,
2034            removed_entries: Default::default(),
2035            scan_id: worktree.scan_id,
2036            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2037            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2038            removed_repositories: Default::default(),
2039        };
2040        for update in proto::split_worktree_update(message) {
2041            session.peer.send(session.connection_id, update.clone())?;
2042        }
2043
2044        // Stream this worktree's diagnostics.
2045        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
2046        if let Some(summary) = worktree_diagnostics.next() {
2047            let message = proto::UpdateDiagnosticSummary {
2048                project_id: project.id.to_proto(),
2049                worktree_id: worktree.id,
2050                summary: Some(summary),
2051                more_summaries: worktree_diagnostics.collect(),
2052            };
2053            session.peer.send(session.connection_id, message)?;
2054        }
2055
2056        for settings_file in worktree.settings_files {
2057            session.peer.send(
2058                session.connection_id,
2059                proto::UpdateWorktreeSettings {
2060                    project_id: project_id.to_proto(),
2061                    worktree_id: worktree.id,
2062                    path: settings_file.path,
2063                    content: Some(settings_file.content),
2064                    kind: Some(settings_file.kind.to_proto() as i32),
2065                },
2066            )?;
2067        }
2068    }
2069
2070    for repository in mem::take(&mut project.repositories) {
2071        for update in split_repository_update(repository) {
2072            session.peer.send(session.connection_id, update)?;
2073        }
2074    }
2075
2076    for language_server in &project.language_servers {
2077        session.peer.send(
2078            session.connection_id,
2079            proto::UpdateLanguageServer {
2080                project_id: project_id.to_proto(),
2081                server_name: Some(language_server.server.name.clone()),
2082                language_server_id: language_server.server.id,
2083                variant: Some(
2084                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2085                        proto::LspDiskBasedDiagnosticsUpdated {},
2086                    ),
2087                ),
2088            },
2089        )?;
2090    }
2091
2092    Ok(())
2093}
2094
2095/// Leave someone elses shared project.
2096async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2097    let sender_id = session.connection_id;
2098    let project_id = ProjectId::from_proto(request.project_id);
2099    let db = session.db().await;
2100
2101    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2102    tracing::info!(
2103        %project_id,
2104        "leave project"
2105    );
2106
2107    project_left(project, &session);
2108    if let Some(room) = room {
2109        room_updated(room, &session.peer);
2110    }
2111
2112    Ok(())
2113}
2114
2115/// Updates other participants with changes to the project
2116async fn update_project(
2117    request: proto::UpdateProject,
2118    response: Response<proto::UpdateProject>,
2119    session: MessageContext,
2120) -> Result<()> {
2121    let project_id = ProjectId::from_proto(request.project_id);
2122    let (room, guest_connection_ids) = &*session
2123        .db()
2124        .await
2125        .update_project(project_id, session.connection_id, &request.worktrees)
2126        .await?;
2127    broadcast(
2128        Some(session.connection_id),
2129        guest_connection_ids.iter().copied(),
2130        |connection_id| {
2131            session
2132                .peer
2133                .forward_send(session.connection_id, connection_id, request.clone())
2134        },
2135    );
2136    if let Some(room) = room {
2137        room_updated(room, &session.peer);
2138    }
2139    response.send(proto::Ack {})?;
2140
2141    Ok(())
2142}
2143
2144/// Updates other participants with changes to the worktree
2145async fn update_worktree(
2146    request: proto::UpdateWorktree,
2147    response: Response<proto::UpdateWorktree>,
2148    session: MessageContext,
2149) -> Result<()> {
2150    let guest_connection_ids = session
2151        .db()
2152        .await
2153        .update_worktree(&request, session.connection_id)
2154        .await?;
2155
2156    broadcast(
2157        Some(session.connection_id),
2158        guest_connection_ids.iter().copied(),
2159        |connection_id| {
2160            session
2161                .peer
2162                .forward_send(session.connection_id, connection_id, request.clone())
2163        },
2164    );
2165    response.send(proto::Ack {})?;
2166    Ok(())
2167}
2168
2169async fn update_repository(
2170    request: proto::UpdateRepository,
2171    response: Response<proto::UpdateRepository>,
2172    session: MessageContext,
2173) -> Result<()> {
2174    let guest_connection_ids = session
2175        .db()
2176        .await
2177        .update_repository(&request, session.connection_id)
2178        .await?;
2179
2180    broadcast(
2181        Some(session.connection_id),
2182        guest_connection_ids.iter().copied(),
2183        |connection_id| {
2184            session
2185                .peer
2186                .forward_send(session.connection_id, connection_id, request.clone())
2187        },
2188    );
2189    response.send(proto::Ack {})?;
2190    Ok(())
2191}
2192
2193async fn remove_repository(
2194    request: proto::RemoveRepository,
2195    response: Response<proto::RemoveRepository>,
2196    session: MessageContext,
2197) -> Result<()> {
2198    let guest_connection_ids = session
2199        .db()
2200        .await
2201        .remove_repository(&request, session.connection_id)
2202        .await?;
2203
2204    broadcast(
2205        Some(session.connection_id),
2206        guest_connection_ids.iter().copied(),
2207        |connection_id| {
2208            session
2209                .peer
2210                .forward_send(session.connection_id, connection_id, request.clone())
2211        },
2212    );
2213    response.send(proto::Ack {})?;
2214    Ok(())
2215}
2216
2217/// Updates other participants with changes to the diagnostics
2218async fn update_diagnostic_summary(
2219    message: proto::UpdateDiagnosticSummary,
2220    session: MessageContext,
2221) -> Result<()> {
2222    let guest_connection_ids = session
2223        .db()
2224        .await
2225        .update_diagnostic_summary(&message, session.connection_id)
2226        .await?;
2227
2228    broadcast(
2229        Some(session.connection_id),
2230        guest_connection_ids.iter().copied(),
2231        |connection_id| {
2232            session
2233                .peer
2234                .forward_send(session.connection_id, connection_id, message.clone())
2235        },
2236    );
2237
2238    Ok(())
2239}
2240
2241/// Updates other participants with changes to the worktree settings
2242async fn update_worktree_settings(
2243    message: proto::UpdateWorktreeSettings,
2244    session: MessageContext,
2245) -> Result<()> {
2246    let guest_connection_ids = session
2247        .db()
2248        .await
2249        .update_worktree_settings(&message, session.connection_id)
2250        .await?;
2251
2252    broadcast(
2253        Some(session.connection_id),
2254        guest_connection_ids.iter().copied(),
2255        |connection_id| {
2256            session
2257                .peer
2258                .forward_send(session.connection_id, connection_id, message.clone())
2259        },
2260    );
2261
2262    Ok(())
2263}
2264
2265/// Notify other participants that a language server has started.
2266async fn start_language_server(
2267    request: proto::StartLanguageServer,
2268    session: MessageContext,
2269) -> Result<()> {
2270    let guest_connection_ids = session
2271        .db()
2272        .await
2273        .start_language_server(&request, session.connection_id)
2274        .await?;
2275
2276    broadcast(
2277        Some(session.connection_id),
2278        guest_connection_ids.iter().copied(),
2279        |connection_id| {
2280            session
2281                .peer
2282                .forward_send(session.connection_id, connection_id, request.clone())
2283        },
2284    );
2285    Ok(())
2286}
2287
2288/// Notify other participants that a language server has changed.
2289async fn update_language_server(
2290    request: proto::UpdateLanguageServer,
2291    session: MessageContext,
2292) -> Result<()> {
2293    let project_id = ProjectId::from_proto(request.project_id);
2294    let db = session.db().await;
2295
2296    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2297        && let Some(capabilities) = update.capabilities.clone()
2298    {
2299        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2300            .await?;
2301    }
2302
2303    let project_connection_ids = db
2304        .project_connection_ids(project_id, session.connection_id, true)
2305        .await?;
2306    broadcast(
2307        Some(session.connection_id),
2308        project_connection_ids.iter().copied(),
2309        |connection_id| {
2310            session
2311                .peer
2312                .forward_send(session.connection_id, connection_id, request.clone())
2313        },
2314    );
2315    Ok(())
2316}
2317
2318/// forward a project request to the host. These requests should be read only
2319/// as guests are allowed to send them.
2320async fn forward_read_only_project_request<T>(
2321    request: T,
2322    response: Response<T>,
2323    session: MessageContext,
2324) -> Result<()>
2325where
2326    T: EntityMessage + RequestMessage,
2327{
2328    let project_id = ProjectId::from_proto(request.remote_entity_id());
2329    let host_connection_id = session
2330        .db()
2331        .await
2332        .host_for_read_only_project_request(project_id, session.connection_id)
2333        .await?;
2334    let payload = session.forward_request(host_connection_id, request).await?;
2335    response.send(payload)?;
2336    Ok(())
2337}
2338
2339/// forward a project request to the host. These requests are disallowed
2340/// for guests.
2341async fn forward_mutating_project_request<T>(
2342    request: T,
2343    response: Response<T>,
2344    session: MessageContext,
2345) -> Result<()>
2346where
2347    T: EntityMessage + RequestMessage,
2348{
2349    let project_id = ProjectId::from_proto(request.remote_entity_id());
2350
2351    let host_connection_id = session
2352        .db()
2353        .await
2354        .host_for_mutating_project_request(project_id, session.connection_id)
2355        .await?;
2356    let payload = session.forward_request(host_connection_id, request).await?;
2357    response.send(payload)?;
2358    Ok(())
2359}
2360
2361// todo(lsp) remove after Zed Stable hits v0.204.x
2362async fn multi_lsp_query(
2363    request: MultiLspQuery,
2364    response: Response<MultiLspQuery>,
2365    session: MessageContext,
2366) -> Result<()> {
2367    tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2368    tracing::info!("multi_lsp_query message received");
2369    forward_mutating_project_request(request, response, session).await
2370}
2371
2372async fn lsp_query(
2373    request: proto::LspQuery,
2374    response: Response<proto::LspQuery>,
2375    session: MessageContext,
2376) -> Result<()> {
2377    let (name, should_write) = request.query_name_and_write_permissions();
2378    tracing::Span::current().record("lsp_query_request", name);
2379    tracing::info!("lsp_query message received");
2380    if should_write {
2381        forward_mutating_project_request(request, response, session).await
2382    } else {
2383        forward_read_only_project_request(request, response, session).await
2384    }
2385}
2386
2387/// Notify other participants that a new buffer has been created
2388async fn create_buffer_for_peer(
2389    request: proto::CreateBufferForPeer,
2390    session: MessageContext,
2391) -> Result<()> {
2392    session
2393        .db()
2394        .await
2395        .check_user_is_project_host(
2396            ProjectId::from_proto(request.project_id),
2397            session.connection_id,
2398        )
2399        .await?;
2400    let peer_id = request.peer_id.context("invalid peer id")?;
2401    session
2402        .peer
2403        .forward_send(session.connection_id, peer_id.into(), request)?;
2404    Ok(())
2405}
2406
2407/// Notify other participants that a buffer has been updated. This is
2408/// allowed for guests as long as the update is limited to selections.
2409async fn update_buffer(
2410    request: proto::UpdateBuffer,
2411    response: Response<proto::UpdateBuffer>,
2412    session: MessageContext,
2413) -> Result<()> {
2414    let project_id = ProjectId::from_proto(request.project_id);
2415    let mut capability = Capability::ReadOnly;
2416
2417    for op in request.operations.iter() {
2418        match op.variant {
2419            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2420            Some(_) => capability = Capability::ReadWrite,
2421        }
2422    }
2423
2424    let host = {
2425        let guard = session
2426            .db()
2427            .await
2428            .connections_for_buffer_update(project_id, session.connection_id, capability)
2429            .await?;
2430
2431        let (host, guests) = &*guard;
2432
2433        broadcast(
2434            Some(session.connection_id),
2435            guests.clone(),
2436            |connection_id| {
2437                session
2438                    .peer
2439                    .forward_send(session.connection_id, connection_id, request.clone())
2440            },
2441        );
2442
2443        *host
2444    };
2445
2446    if host != session.connection_id {
2447        session.forward_request(host, request.clone()).await?;
2448    }
2449
2450    response.send(proto::Ack {})?;
2451    Ok(())
2452}
2453
2454async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2455    let project_id = ProjectId::from_proto(message.project_id);
2456
2457    let operation = message.operation.as_ref().context("invalid operation")?;
2458    let capability = match operation.variant.as_ref() {
2459        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2460            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2461                match buffer_op.variant {
2462                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2463                        Capability::ReadOnly
2464                    }
2465                    _ => Capability::ReadWrite,
2466                }
2467            } else {
2468                Capability::ReadWrite
2469            }
2470        }
2471        Some(_) => Capability::ReadWrite,
2472        None => Capability::ReadOnly,
2473    };
2474
2475    let guard = session
2476        .db()
2477        .await
2478        .connections_for_buffer_update(project_id, session.connection_id, capability)
2479        .await?;
2480
2481    let (host, guests) = &*guard;
2482
2483    broadcast(
2484        Some(session.connection_id),
2485        guests.iter().chain([host]).copied(),
2486        |connection_id| {
2487            session
2488                .peer
2489                .forward_send(session.connection_id, connection_id, message.clone())
2490        },
2491    );
2492
2493    Ok(())
2494}
2495
2496/// Notify other participants that a project has been updated.
2497async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2498    request: T,
2499    session: MessageContext,
2500) -> Result<()> {
2501    let project_id = ProjectId::from_proto(request.remote_entity_id());
2502    let project_connection_ids = session
2503        .db()
2504        .await
2505        .project_connection_ids(project_id, session.connection_id, false)
2506        .await?;
2507
2508    broadcast(
2509        Some(session.connection_id),
2510        project_connection_ids.iter().copied(),
2511        |connection_id| {
2512            session
2513                .peer
2514                .forward_send(session.connection_id, connection_id, request.clone())
2515        },
2516    );
2517    Ok(())
2518}
2519
2520/// Start following another user in a call.
2521async fn follow(
2522    request: proto::Follow,
2523    response: Response<proto::Follow>,
2524    session: MessageContext,
2525) -> Result<()> {
2526    let room_id = RoomId::from_proto(request.room_id);
2527    let project_id = request.project_id.map(ProjectId::from_proto);
2528    let leader_id = request.leader_id.context("invalid leader id")?.into();
2529    let follower_id = session.connection_id;
2530
2531    session
2532        .db()
2533        .await
2534        .check_room_participants(room_id, leader_id, session.connection_id)
2535        .await?;
2536
2537    let response_payload = session.forward_request(leader_id, request).await?;
2538    response.send(response_payload)?;
2539
2540    if let Some(project_id) = project_id {
2541        let room = session
2542            .db()
2543            .await
2544            .follow(room_id, project_id, leader_id, follower_id)
2545            .await?;
2546        room_updated(&room, &session.peer);
2547    }
2548
2549    Ok(())
2550}
2551
2552/// Stop following another user in a call.
2553async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2554    let room_id = RoomId::from_proto(request.room_id);
2555    let project_id = request.project_id.map(ProjectId::from_proto);
2556    let leader_id = request.leader_id.context("invalid leader id")?.into();
2557    let follower_id = session.connection_id;
2558
2559    session
2560        .db()
2561        .await
2562        .check_room_participants(room_id, leader_id, session.connection_id)
2563        .await?;
2564
2565    session
2566        .peer
2567        .forward_send(session.connection_id, leader_id, request)?;
2568
2569    if let Some(project_id) = project_id {
2570        let room = session
2571            .db()
2572            .await
2573            .unfollow(room_id, project_id, leader_id, follower_id)
2574            .await?;
2575        room_updated(&room, &session.peer);
2576    }
2577
2578    Ok(())
2579}
2580
2581/// Notify everyone following you of your current location.
2582async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2583    let room_id = RoomId::from_proto(request.room_id);
2584    let database = session.db.lock().await;
2585
2586    let connection_ids = if let Some(project_id) = request.project_id {
2587        let project_id = ProjectId::from_proto(project_id);
2588        database
2589            .project_connection_ids(project_id, session.connection_id, true)
2590            .await?
2591    } else {
2592        database
2593            .room_connection_ids(room_id, session.connection_id)
2594            .await?
2595    };
2596
2597    // For now, don't send view update messages back to that view's current leader.
2598    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2599        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2600        _ => None,
2601    });
2602
2603    for connection_id in connection_ids.iter().cloned() {
2604        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2605            session
2606                .peer
2607                .forward_send(session.connection_id, connection_id, request.clone())?;
2608        }
2609    }
2610    Ok(())
2611}
2612
2613/// Get public data about users.
2614async fn get_users(
2615    request: proto::GetUsers,
2616    response: Response<proto::GetUsers>,
2617    session: MessageContext,
2618) -> Result<()> {
2619    let user_ids = request
2620        .user_ids
2621        .into_iter()
2622        .map(UserId::from_proto)
2623        .collect();
2624    let users = session
2625        .db()
2626        .await
2627        .get_users_by_ids(user_ids)
2628        .await?
2629        .into_iter()
2630        .map(|user| proto::User {
2631            id: user.id.to_proto(),
2632            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2633            github_login: user.github_login,
2634            name: user.name,
2635        })
2636        .collect();
2637    response.send(proto::UsersResponse { users })?;
2638    Ok(())
2639}
2640
2641/// Search for users (to invite) buy Github login
2642async fn fuzzy_search_users(
2643    request: proto::FuzzySearchUsers,
2644    response: Response<proto::FuzzySearchUsers>,
2645    session: MessageContext,
2646) -> Result<()> {
2647    let query = request.query;
2648    let users = match query.len() {
2649        0 => vec![],
2650        1 | 2 => session
2651            .db()
2652            .await
2653            .get_user_by_github_login(&query)
2654            .await?
2655            .into_iter()
2656            .collect(),
2657        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2658    };
2659    let users = users
2660        .into_iter()
2661        .filter(|user| user.id != session.user_id())
2662        .map(|user| proto::User {
2663            id: user.id.to_proto(),
2664            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2665            github_login: user.github_login,
2666            name: user.name,
2667        })
2668        .collect();
2669    response.send(proto::UsersResponse { users })?;
2670    Ok(())
2671}
2672
2673/// Send a contact request to another user.
2674async fn request_contact(
2675    request: proto::RequestContact,
2676    response: Response<proto::RequestContact>,
2677    session: MessageContext,
2678) -> Result<()> {
2679    let requester_id = session.user_id();
2680    let responder_id = UserId::from_proto(request.responder_id);
2681    if requester_id == responder_id {
2682        return Err(anyhow!("cannot add yourself as a contact"))?;
2683    }
2684
2685    let notifications = session
2686        .db()
2687        .await
2688        .send_contact_request(requester_id, responder_id)
2689        .await?;
2690
2691    // Update outgoing contact requests of requester
2692    let mut update = proto::UpdateContacts::default();
2693    update.outgoing_requests.push(responder_id.to_proto());
2694    for connection_id in session
2695        .connection_pool()
2696        .await
2697        .user_connection_ids(requester_id)
2698    {
2699        session.peer.send(connection_id, update.clone())?;
2700    }
2701
2702    // Update incoming contact requests of responder
2703    let mut update = proto::UpdateContacts::default();
2704    update
2705        .incoming_requests
2706        .push(proto::IncomingContactRequest {
2707            requester_id: requester_id.to_proto(),
2708        });
2709    let connection_pool = session.connection_pool().await;
2710    for connection_id in connection_pool.user_connection_ids(responder_id) {
2711        session.peer.send(connection_id, update.clone())?;
2712    }
2713
2714    send_notifications(&connection_pool, &session.peer, notifications);
2715
2716    response.send(proto::Ack {})?;
2717    Ok(())
2718}
2719
2720/// Accept or decline a contact request
2721async fn respond_to_contact_request(
2722    request: proto::RespondToContactRequest,
2723    response: Response<proto::RespondToContactRequest>,
2724    session: MessageContext,
2725) -> Result<()> {
2726    let responder_id = session.user_id();
2727    let requester_id = UserId::from_proto(request.requester_id);
2728    let db = session.db().await;
2729    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2730        db.dismiss_contact_notification(responder_id, requester_id)
2731            .await?;
2732    } else {
2733        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2734
2735        let notifications = db
2736            .respond_to_contact_request(responder_id, requester_id, accept)
2737            .await?;
2738        let requester_busy = db.is_user_busy(requester_id).await?;
2739        let responder_busy = db.is_user_busy(responder_id).await?;
2740
2741        let pool = session.connection_pool().await;
2742        // Update responder with new contact
2743        let mut update = proto::UpdateContacts::default();
2744        if accept {
2745            update
2746                .contacts
2747                .push(contact_for_user(requester_id, requester_busy, &pool));
2748        }
2749        update
2750            .remove_incoming_requests
2751            .push(requester_id.to_proto());
2752        for connection_id in pool.user_connection_ids(responder_id) {
2753            session.peer.send(connection_id, update.clone())?;
2754        }
2755
2756        // Update requester with new contact
2757        let mut update = proto::UpdateContacts::default();
2758        if accept {
2759            update
2760                .contacts
2761                .push(contact_for_user(responder_id, responder_busy, &pool));
2762        }
2763        update
2764            .remove_outgoing_requests
2765            .push(responder_id.to_proto());
2766
2767        for connection_id in pool.user_connection_ids(requester_id) {
2768            session.peer.send(connection_id, update.clone())?;
2769        }
2770
2771        send_notifications(&pool, &session.peer, notifications);
2772    }
2773
2774    response.send(proto::Ack {})?;
2775    Ok(())
2776}
2777
2778/// Remove a contact.
2779async fn remove_contact(
2780    request: proto::RemoveContact,
2781    response: Response<proto::RemoveContact>,
2782    session: MessageContext,
2783) -> Result<()> {
2784    let requester_id = session.user_id();
2785    let responder_id = UserId::from_proto(request.user_id);
2786    let db = session.db().await;
2787    let (contact_accepted, deleted_notification_id) =
2788        db.remove_contact(requester_id, responder_id).await?;
2789
2790    let pool = session.connection_pool().await;
2791    // Update outgoing contact requests of requester
2792    let mut update = proto::UpdateContacts::default();
2793    if contact_accepted {
2794        update.remove_contacts.push(responder_id.to_proto());
2795    } else {
2796        update
2797            .remove_outgoing_requests
2798            .push(responder_id.to_proto());
2799    }
2800    for connection_id in pool.user_connection_ids(requester_id) {
2801        session.peer.send(connection_id, update.clone())?;
2802    }
2803
2804    // Update incoming contact requests of responder
2805    let mut update = proto::UpdateContacts::default();
2806    if contact_accepted {
2807        update.remove_contacts.push(requester_id.to_proto());
2808    } else {
2809        update
2810            .remove_incoming_requests
2811            .push(requester_id.to_proto());
2812    }
2813    for connection_id in pool.user_connection_ids(responder_id) {
2814        session.peer.send(connection_id, update.clone())?;
2815        if let Some(notification_id) = deleted_notification_id {
2816            session.peer.send(
2817                connection_id,
2818                proto::DeleteNotification {
2819                    notification_id: notification_id.to_proto(),
2820                },
2821            )?;
2822        }
2823    }
2824
2825    response.send(proto::Ack {})?;
2826    Ok(())
2827}
2828
2829fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2830    version.0.minor() < 139
2831}
2832
2833async fn subscribe_to_channels(
2834    _: proto::SubscribeToChannels,
2835    session: MessageContext,
2836) -> Result<()> {
2837    subscribe_user_to_channels(session.user_id(), &session).await?;
2838    Ok(())
2839}
2840
2841async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2842    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2843    let mut pool = session.connection_pool().await;
2844    for membership in &channels_for_user.channel_memberships {
2845        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2846    }
2847    session.peer.send(
2848        session.connection_id,
2849        build_update_user_channels(&channels_for_user),
2850    )?;
2851    session.peer.send(
2852        session.connection_id,
2853        build_channels_update(channels_for_user),
2854    )?;
2855    Ok(())
2856}
2857
2858/// Creates a new channel.
2859async fn create_channel(
2860    request: proto::CreateChannel,
2861    response: Response<proto::CreateChannel>,
2862    session: MessageContext,
2863) -> Result<()> {
2864    let db = session.db().await;
2865
2866    let parent_id = request.parent_id.map(ChannelId::from_proto);
2867    let (channel, membership) = db
2868        .create_channel(&request.name, parent_id, session.user_id())
2869        .await?;
2870
2871    let root_id = channel.root_id();
2872    let channel = Channel::from_model(channel);
2873
2874    response.send(proto::CreateChannelResponse {
2875        channel: Some(channel.to_proto()),
2876        parent_id: request.parent_id,
2877    })?;
2878
2879    let mut connection_pool = session.connection_pool().await;
2880    if let Some(membership) = membership {
2881        connection_pool.subscribe_to_channel(
2882            membership.user_id,
2883            membership.channel_id,
2884            membership.role,
2885        );
2886        let update = proto::UpdateUserChannels {
2887            channel_memberships: vec![proto::ChannelMembership {
2888                channel_id: membership.channel_id.to_proto(),
2889                role: membership.role.into(),
2890            }],
2891            ..Default::default()
2892        };
2893        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2894            session.peer.send(connection_id, update.clone())?;
2895        }
2896    }
2897
2898    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2899        if !role.can_see_channel(channel.visibility) {
2900            continue;
2901        }
2902
2903        let update = proto::UpdateChannels {
2904            channels: vec![channel.to_proto()],
2905            ..Default::default()
2906        };
2907        session.peer.send(connection_id, update.clone())?;
2908    }
2909
2910    Ok(())
2911}
2912
2913/// Delete a channel
2914async fn delete_channel(
2915    request: proto::DeleteChannel,
2916    response: Response<proto::DeleteChannel>,
2917    session: MessageContext,
2918) -> Result<()> {
2919    let db = session.db().await;
2920
2921    let channel_id = request.channel_id;
2922    let (root_channel, removed_channels) = db
2923        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2924        .await?;
2925    response.send(proto::Ack {})?;
2926
2927    // Notify members of removed channels
2928    let mut update = proto::UpdateChannels::default();
2929    update
2930        .delete_channels
2931        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2932
2933    let connection_pool = session.connection_pool().await;
2934    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2935        session.peer.send(connection_id, update.clone())?;
2936    }
2937
2938    Ok(())
2939}
2940
2941/// Invite someone to join a channel.
2942async fn invite_channel_member(
2943    request: proto::InviteChannelMember,
2944    response: Response<proto::InviteChannelMember>,
2945    session: MessageContext,
2946) -> Result<()> {
2947    let db = session.db().await;
2948    let channel_id = ChannelId::from_proto(request.channel_id);
2949    let invitee_id = UserId::from_proto(request.user_id);
2950    let InviteMemberResult {
2951        channel,
2952        notifications,
2953    } = db
2954        .invite_channel_member(
2955            channel_id,
2956            invitee_id,
2957            session.user_id(),
2958            request.role().into(),
2959        )
2960        .await?;
2961
2962    let update = proto::UpdateChannels {
2963        channel_invitations: vec![channel.to_proto()],
2964        ..Default::default()
2965    };
2966
2967    let connection_pool = session.connection_pool().await;
2968    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2969        session.peer.send(connection_id, update.clone())?;
2970    }
2971
2972    send_notifications(&connection_pool, &session.peer, notifications);
2973
2974    response.send(proto::Ack {})?;
2975    Ok(())
2976}
2977
2978/// remove someone from a channel
2979async fn remove_channel_member(
2980    request: proto::RemoveChannelMember,
2981    response: Response<proto::RemoveChannelMember>,
2982    session: MessageContext,
2983) -> Result<()> {
2984    let db = session.db().await;
2985    let channel_id = ChannelId::from_proto(request.channel_id);
2986    let member_id = UserId::from_proto(request.user_id);
2987
2988    let RemoveChannelMemberResult {
2989        membership_update,
2990        notification_id,
2991    } = db
2992        .remove_channel_member(channel_id, member_id, session.user_id())
2993        .await?;
2994
2995    let mut connection_pool = session.connection_pool().await;
2996    notify_membership_updated(
2997        &mut connection_pool,
2998        membership_update,
2999        member_id,
3000        &session.peer,
3001    );
3002    for connection_id in connection_pool.user_connection_ids(member_id) {
3003        if let Some(notification_id) = notification_id {
3004            session
3005                .peer
3006                .send(
3007                    connection_id,
3008                    proto::DeleteNotification {
3009                        notification_id: notification_id.to_proto(),
3010                    },
3011                )
3012                .trace_err();
3013        }
3014    }
3015
3016    response.send(proto::Ack {})?;
3017    Ok(())
3018}
3019
3020/// Toggle the channel between public and private.
3021/// Care is taken to maintain the invariant that public channels only descend from public channels,
3022/// (though members-only channels can appear at any point in the hierarchy).
3023async fn set_channel_visibility(
3024    request: proto::SetChannelVisibility,
3025    response: Response<proto::SetChannelVisibility>,
3026    session: MessageContext,
3027) -> Result<()> {
3028    let db = session.db().await;
3029    let channel_id = ChannelId::from_proto(request.channel_id);
3030    let visibility = request.visibility().into();
3031
3032    let channel_model = db
3033        .set_channel_visibility(channel_id, visibility, session.user_id())
3034        .await?;
3035    let root_id = channel_model.root_id();
3036    let channel = Channel::from_model(channel_model);
3037
3038    let mut connection_pool = session.connection_pool().await;
3039    for (user_id, role) in connection_pool
3040        .channel_user_ids(root_id)
3041        .collect::<Vec<_>>()
3042        .into_iter()
3043    {
3044        let update = if role.can_see_channel(channel.visibility) {
3045            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3046            proto::UpdateChannels {
3047                channels: vec![channel.to_proto()],
3048                ..Default::default()
3049            }
3050        } else {
3051            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3052            proto::UpdateChannels {
3053                delete_channels: vec![channel.id.to_proto()],
3054                ..Default::default()
3055            }
3056        };
3057
3058        for connection_id in connection_pool.user_connection_ids(user_id) {
3059            session.peer.send(connection_id, update.clone())?;
3060        }
3061    }
3062
3063    response.send(proto::Ack {})?;
3064    Ok(())
3065}
3066
3067/// Alter the role for a user in the channel.
3068async fn set_channel_member_role(
3069    request: proto::SetChannelMemberRole,
3070    response: Response<proto::SetChannelMemberRole>,
3071    session: MessageContext,
3072) -> Result<()> {
3073    let db = session.db().await;
3074    let channel_id = ChannelId::from_proto(request.channel_id);
3075    let member_id = UserId::from_proto(request.user_id);
3076    let result = db
3077        .set_channel_member_role(
3078            channel_id,
3079            session.user_id(),
3080            member_id,
3081            request.role().into(),
3082        )
3083        .await?;
3084
3085    match result {
3086        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3087            let mut connection_pool = session.connection_pool().await;
3088            notify_membership_updated(
3089                &mut connection_pool,
3090                membership_update,
3091                member_id,
3092                &session.peer,
3093            )
3094        }
3095        db::SetMemberRoleResult::InviteUpdated(channel) => {
3096            let update = proto::UpdateChannels {
3097                channel_invitations: vec![channel.to_proto()],
3098                ..Default::default()
3099            };
3100
3101            for connection_id in session
3102                .connection_pool()
3103                .await
3104                .user_connection_ids(member_id)
3105            {
3106                session.peer.send(connection_id, update.clone())?;
3107            }
3108        }
3109    }
3110
3111    response.send(proto::Ack {})?;
3112    Ok(())
3113}
3114
3115/// Change the name of a channel
3116async fn rename_channel(
3117    request: proto::RenameChannel,
3118    response: Response<proto::RenameChannel>,
3119    session: MessageContext,
3120) -> Result<()> {
3121    let db = session.db().await;
3122    let channel_id = ChannelId::from_proto(request.channel_id);
3123    let channel_model = db
3124        .rename_channel(channel_id, session.user_id(), &request.name)
3125        .await?;
3126    let root_id = channel_model.root_id();
3127    let channel = Channel::from_model(channel_model);
3128
3129    response.send(proto::RenameChannelResponse {
3130        channel: Some(channel.to_proto()),
3131    })?;
3132
3133    let connection_pool = session.connection_pool().await;
3134    let update = proto::UpdateChannels {
3135        channels: vec![channel.to_proto()],
3136        ..Default::default()
3137    };
3138    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3139        if role.can_see_channel(channel.visibility) {
3140            session.peer.send(connection_id, update.clone())?;
3141        }
3142    }
3143
3144    Ok(())
3145}
3146
3147/// Move a channel to a new parent.
3148async fn move_channel(
3149    request: proto::MoveChannel,
3150    response: Response<proto::MoveChannel>,
3151    session: MessageContext,
3152) -> Result<()> {
3153    let channel_id = ChannelId::from_proto(request.channel_id);
3154    let to = ChannelId::from_proto(request.to);
3155
3156    let (root_id, channels) = session
3157        .db()
3158        .await
3159        .move_channel(channel_id, to, session.user_id())
3160        .await?;
3161
3162    let connection_pool = session.connection_pool().await;
3163    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3164        let channels = channels
3165            .iter()
3166            .filter_map(|channel| {
3167                if role.can_see_channel(channel.visibility) {
3168                    Some(channel.to_proto())
3169                } else {
3170                    None
3171                }
3172            })
3173            .collect::<Vec<_>>();
3174        if channels.is_empty() {
3175            continue;
3176        }
3177
3178        let update = proto::UpdateChannels {
3179            channels,
3180            ..Default::default()
3181        };
3182
3183        session.peer.send(connection_id, update.clone())?;
3184    }
3185
3186    response.send(Ack {})?;
3187    Ok(())
3188}
3189
3190async fn reorder_channel(
3191    request: proto::ReorderChannel,
3192    response: Response<proto::ReorderChannel>,
3193    session: MessageContext,
3194) -> Result<()> {
3195    let channel_id = ChannelId::from_proto(request.channel_id);
3196    let direction = request.direction();
3197
3198    let updated_channels = session
3199        .db()
3200        .await
3201        .reorder_channel(channel_id, direction, session.user_id())
3202        .await?;
3203
3204    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3205        let connection_pool = session.connection_pool().await;
3206        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3207            let channels = updated_channels
3208                .iter()
3209                .filter_map(|channel| {
3210                    if role.can_see_channel(channel.visibility) {
3211                        Some(channel.to_proto())
3212                    } else {
3213                        None
3214                    }
3215                })
3216                .collect::<Vec<_>>();
3217
3218            if channels.is_empty() {
3219                continue;
3220            }
3221
3222            let update = proto::UpdateChannels {
3223                channels,
3224                ..Default::default()
3225            };
3226
3227            session.peer.send(connection_id, update.clone())?;
3228        }
3229    }
3230
3231    response.send(Ack {})?;
3232    Ok(())
3233}
3234
3235/// Get the list of channel members
3236async fn get_channel_members(
3237    request: proto::GetChannelMembers,
3238    response: Response<proto::GetChannelMembers>,
3239    session: MessageContext,
3240) -> Result<()> {
3241    let db = session.db().await;
3242    let channel_id = ChannelId::from_proto(request.channel_id);
3243    let limit = if request.limit == 0 {
3244        u16::MAX as u64
3245    } else {
3246        request.limit
3247    };
3248    let (members, users) = db
3249        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3250        .await?;
3251    response.send(proto::GetChannelMembersResponse { members, users })?;
3252    Ok(())
3253}
3254
3255/// Accept or decline a channel invitation.
3256async fn respond_to_channel_invite(
3257    request: proto::RespondToChannelInvite,
3258    response: Response<proto::RespondToChannelInvite>,
3259    session: MessageContext,
3260) -> Result<()> {
3261    let db = session.db().await;
3262    let channel_id = ChannelId::from_proto(request.channel_id);
3263    let RespondToChannelInvite {
3264        membership_update,
3265        notifications,
3266    } = db
3267        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3268        .await?;
3269
3270    let mut connection_pool = session.connection_pool().await;
3271    if let Some(membership_update) = membership_update {
3272        notify_membership_updated(
3273            &mut connection_pool,
3274            membership_update,
3275            session.user_id(),
3276            &session.peer,
3277        );
3278    } else {
3279        let update = proto::UpdateChannels {
3280            remove_channel_invitations: vec![channel_id.to_proto()],
3281            ..Default::default()
3282        };
3283
3284        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3285            session.peer.send(connection_id, update.clone())?;
3286        }
3287    };
3288
3289    send_notifications(&connection_pool, &session.peer, notifications);
3290
3291    response.send(proto::Ack {})?;
3292
3293    Ok(())
3294}
3295
3296/// Join the channels' room
3297async fn join_channel(
3298    request: proto::JoinChannel,
3299    response: Response<proto::JoinChannel>,
3300    session: MessageContext,
3301) -> Result<()> {
3302    let channel_id = ChannelId::from_proto(request.channel_id);
3303    join_channel_internal(channel_id, Box::new(response), session).await
3304}
3305
3306trait JoinChannelInternalResponse {
3307    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3308}
3309impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3310    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3311        Response::<proto::JoinChannel>::send(self, result)
3312    }
3313}
3314impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3315    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3316        Response::<proto::JoinRoom>::send(self, result)
3317    }
3318}
3319
/// Join the room that belongs to `channel_id` on behalf of the session's user and
/// reply through `response`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, that
///    connection is first removed from its room via `leave_room_for_session`.
/// 2. The user joins the channel's room. The reply carries:
///    - the room,
///    - the room's channel id,
///    - LiveKit connection info, when a LiveKit client is configured and a token
///      could be minted. Guests receive a token with `can_publish: false`.
/// 3. Any membership change is pushed to the user's connections, and the updated
///    room is broadcast to its participants.
/// 4. Connections subscribed to the channel receive its new participant list, and
///    the user's accepted contacts receive the user's updated status.
///
/// Note: the reply is sent before steps 3-4. An error in those steps is returned
/// from the handler after the client has already received a successful response.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    // Scoped so the database handle and the connection-pool guard taken inside are
    // released before `channel_updated` below acquires the pool again.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` calls `session.db()` itself, so release our
            // handle first and re-acquire it afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when minting the token
        // fails (the error is logged by `trace_err` and the closure returns `None`).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        // The client has been answered; everything below is fan-out to other peers.
        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3413
3414/// Start editing the channel notes
3415async fn join_channel_buffer(
3416    request: proto::JoinChannelBuffer,
3417    response: Response<proto::JoinChannelBuffer>,
3418    session: MessageContext,
3419) -> Result<()> {
3420    let db = session.db().await;
3421    let channel_id = ChannelId::from_proto(request.channel_id);
3422
3423    let open_response = db
3424        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3425        .await?;
3426
3427    let collaborators = open_response.collaborators.clone();
3428    response.send(open_response)?;
3429
3430    let update = UpdateChannelBufferCollaborators {
3431        channel_id: channel_id.to_proto(),
3432        collaborators: collaborators.clone(),
3433    };
3434    channel_buffer_updated(
3435        session.connection_id,
3436        collaborators
3437            .iter()
3438            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3439        &update,
3440        &session.peer,
3441    );
3442
3443    Ok(())
3444}
3445
3446/// Edit the channel notes
3447async fn update_channel_buffer(
3448    request: proto::UpdateChannelBuffer,
3449    session: MessageContext,
3450) -> Result<()> {
3451    let db = session.db().await;
3452    let channel_id = ChannelId::from_proto(request.channel_id);
3453
3454    let (collaborators, epoch, version) = db
3455        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3456        .await?;
3457
3458    channel_buffer_updated(
3459        session.connection_id,
3460        collaborators.clone(),
3461        &proto::UpdateChannelBuffer {
3462            channel_id: channel_id.to_proto(),
3463            operations: request.operations,
3464        },
3465        &session.peer,
3466    );
3467
3468    let pool = &*session.connection_pool().await;
3469
3470    let non_collaborators =
3471        pool.channel_connection_ids(channel_id)
3472            .filter_map(|(connection_id, _)| {
3473                if collaborators.contains(&connection_id) {
3474                    None
3475                } else {
3476                    Some(connection_id)
3477                }
3478            });
3479
3480    broadcast(None, non_collaborators, |peer_id| {
3481        session.peer.send(
3482            peer_id,
3483            proto::UpdateChannels {
3484                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3485                    channel_id: channel_id.to_proto(),
3486                    epoch: epoch as u64,
3487                    version: version.clone(),
3488                }],
3489                ..Default::default()
3490            },
3491        )
3492    });
3493
3494    Ok(())
3495}
3496
3497/// Rejoin the channel notes after a connection blip
3498async fn rejoin_channel_buffers(
3499    request: proto::RejoinChannelBuffers,
3500    response: Response<proto::RejoinChannelBuffers>,
3501    session: MessageContext,
3502) -> Result<()> {
3503    let db = session.db().await;
3504    let buffers = db
3505        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3506        .await?;
3507
3508    for rejoined_buffer in &buffers {
3509        let collaborators_to_notify = rejoined_buffer
3510            .buffer
3511            .collaborators
3512            .iter()
3513            .filter_map(|c| Some(c.peer_id?.into()));
3514        channel_buffer_updated(
3515            session.connection_id,
3516            collaborators_to_notify,
3517            &proto::UpdateChannelBufferCollaborators {
3518                channel_id: rejoined_buffer.buffer.channel_id,
3519                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3520            },
3521            &session.peer,
3522        );
3523    }
3524
3525    response.send(proto::RejoinChannelBuffersResponse {
3526        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3527    })?;
3528
3529    Ok(())
3530}
3531
3532/// Stop editing the channel notes
3533async fn leave_channel_buffer(
3534    request: proto::LeaveChannelBuffer,
3535    response: Response<proto::LeaveChannelBuffer>,
3536    session: MessageContext,
3537) -> Result<()> {
3538    let db = session.db().await;
3539    let channel_id = ChannelId::from_proto(request.channel_id);
3540
3541    let left_buffer = db
3542        .leave_channel_buffer(channel_id, session.connection_id)
3543        .await?;
3544
3545    response.send(Ack {})?;
3546
3547    channel_buffer_updated(
3548        session.connection_id,
3549        left_buffer.connections,
3550        &proto::UpdateChannelBufferCollaborators {
3551            channel_id: channel_id.to_proto(),
3552            collaborators: left_buffer.collaborators,
3553        },
3554        &session.peer,
3555    );
3556
3557    Ok(())
3558}
3559
3560fn channel_buffer_updated<T: EnvelopedMessage>(
3561    sender_id: ConnectionId,
3562    collaborators: impl IntoIterator<Item = ConnectionId>,
3563    message: &T,
3564    peer: &Peer,
3565) {
3566    broadcast(Some(sender_id), collaborators, |peer_id| {
3567        peer.send(peer_id, message.clone())
3568    });
3569}
3570
3571fn send_notifications(
3572    connection_pool: &ConnectionPool,
3573    peer: &Peer,
3574    notifications: db::NotificationBatch,
3575) {
3576    for (user_id, notification) in notifications {
3577        for connection_id in connection_pool.user_connection_ids(user_id) {
3578            if let Err(error) = peer.send(
3579                connection_id,
3580                proto::AddNotification {
3581                    notification: Some(notification.clone()),
3582                },
3583            ) {
3584                tracing::error!(
3585                    "failed to send notification to {:?} {}",
3586                    connection_id,
3587                    error
3588                );
3589            }
3590        }
3591    }
3592}
3593
3594/// Send a message to the channel
3595async fn send_channel_message(
3596    _request: proto::SendChannelMessage,
3597    _response: Response<proto::SendChannelMessage>,
3598    _session: MessageContext,
3599) -> Result<()> {
3600    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3601}
3602
3603/// Delete a channel message
3604async fn remove_channel_message(
3605    _request: proto::RemoveChannelMessage,
3606    _response: Response<proto::RemoveChannelMessage>,
3607    _session: MessageContext,
3608) -> Result<()> {
3609    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3610}
3611
3612async fn update_channel_message(
3613    _request: proto::UpdateChannelMessage,
3614    _response: Response<proto::UpdateChannelMessage>,
3615    _session: MessageContext,
3616) -> Result<()> {
3617    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3618}
3619
3620/// Mark a channel message as read
3621async fn acknowledge_channel_message(
3622    _request: proto::AckChannelMessage,
3623    _session: MessageContext,
3624) -> Result<()> {
3625    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3626}
3627
3628/// Mark a buffer version as synced
3629async fn acknowledge_buffer_version(
3630    request: proto::AckBufferOperation,
3631    session: MessageContext,
3632) -> Result<()> {
3633    let buffer_id = BufferId::from_proto(request.buffer_id);
3634    session
3635        .db()
3636        .await
3637        .observe_buffer_version(
3638            buffer_id,
3639            session.user_id(),
3640            request.epoch as i32,
3641            &request.version,
3642        )
3643        .await?;
3644    Ok(())
3645}
3646
3647/// Get a Supermaven API key for the user
3648async fn get_supermaven_api_key(
3649    _request: proto::GetSupermavenApiKey,
3650    response: Response<proto::GetSupermavenApiKey>,
3651    session: MessageContext,
3652) -> Result<()> {
3653    let user_id: String = session.user_id().to_string();
3654    if !session.is_staff() {
3655        return Err(anyhow!("supermaven not enabled for this account"))?;
3656    }
3657
3658    let email = session.email().context("user must have an email")?;
3659
3660    let supermaven_admin_api = session
3661        .supermaven_client
3662        .as_ref()
3663        .context("supermaven not configured")?;
3664
3665    let result = supermaven_admin_api
3666        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3667        .await?;
3668
3669    response.send(proto::GetSupermavenApiKeyResponse {
3670        api_key: result.api_key,
3671    })?;
3672
3673    Ok(())
3674}
3675
3676/// Start receiving chat updates for a channel
3677async fn join_channel_chat(
3678    _request: proto::JoinChannelChat,
3679    _response: Response<proto::JoinChannelChat>,
3680    _session: MessageContext,
3681) -> Result<()> {
3682    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3683}
3684
3685/// Stop receiving chat updates for a channel
3686async fn leave_channel_chat(
3687    _request: proto::LeaveChannelChat,
3688    _session: MessageContext,
3689) -> Result<()> {
3690    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3691}
3692
3693/// Retrieve the chat history for a channel
3694async fn get_channel_messages(
3695    _request: proto::GetChannelMessages,
3696    _response: Response<proto::GetChannelMessages>,
3697    _session: MessageContext,
3698) -> Result<()> {
3699    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3700}
3701
3702/// Retrieve specific chat messages
3703async fn get_channel_messages_by_id(
3704    _request: proto::GetChannelMessagesById,
3705    _response: Response<proto::GetChannelMessagesById>,
3706    _session: MessageContext,
3707) -> Result<()> {
3708    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3709}
3710
3711/// Retrieve the current users notifications
3712async fn get_notifications(
3713    request: proto::GetNotifications,
3714    response: Response<proto::GetNotifications>,
3715    session: MessageContext,
3716) -> Result<()> {
3717    let notifications = session
3718        .db()
3719        .await
3720        .get_notifications(
3721            session.user_id(),
3722            NOTIFICATION_COUNT_PER_PAGE,
3723            request.before_id.map(db::NotificationId::from_proto),
3724        )
3725        .await?;
3726    response.send(proto::GetNotificationsResponse {
3727        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3728        notifications,
3729    })?;
3730    Ok(())
3731}
3732
3733/// Mark notifications as read
3734async fn mark_notification_as_read(
3735    request: proto::MarkNotificationRead,
3736    response: Response<proto::MarkNotificationRead>,
3737    session: MessageContext,
3738) -> Result<()> {
3739    let database = &session.db().await;
3740    let notifications = database
3741        .mark_notification_as_read_by_id(
3742            session.user_id(),
3743            NotificationId::from_proto(request.notification_id),
3744        )
3745        .await?;
3746    send_notifications(
3747        &*session.connection_pool().await,
3748        &session.peer,
3749        notifications,
3750    );
3751    response.send(proto::Ack {})?;
3752    Ok(())
3753}
3754
3755fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3756    let message = match message {
3757        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3758        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3759        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3760        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3761        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3762            code: frame.code.into(),
3763            reason: frame.reason.as_str().to_owned().into(),
3764        })),
3765        // We should never receive a frame while reading the message, according
3766        // to the `tungstenite` maintainers:
3767        //
3768        // > It cannot occur when you read messages from the WebSocket, but it
3769        // > can be used when you want to send the raw frames (e.g. you want to
3770        // > send the frames to the WebSocket without composing the full message first).
3771        // >
3772        // > — https://github.com/snapview/tungstenite-rs/issues/268
3773        TungsteniteMessage::Frame(_) => {
3774            bail!("received an unexpected frame while reading the message")
3775        }
3776    };
3777
3778    Ok(message)
3779}
3780
3781fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3782    match message {
3783        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3784        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3785        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3786        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3787        AxumMessage::Close(frame) => {
3788            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3789                code: frame.code.into(),
3790                reason: frame.reason.as_ref().into(),
3791            }))
3792        }
3793    }
3794}
3795
3796fn notify_membership_updated(
3797    connection_pool: &mut ConnectionPool,
3798    result: MembershipUpdated,
3799    user_id: UserId,
3800    peer: &Peer,
3801) {
3802    for membership in &result.new_channels.channel_memberships {
3803        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3804    }
3805    for channel_id in &result.removed_channels {
3806        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3807    }
3808
3809    let user_channels_update = proto::UpdateUserChannels {
3810        channel_memberships: result
3811            .new_channels
3812            .channel_memberships
3813            .iter()
3814            .map(|cm| proto::ChannelMembership {
3815                channel_id: cm.channel_id.to_proto(),
3816                role: cm.role.into(),
3817            })
3818            .collect(),
3819        ..Default::default()
3820    };
3821
3822    let mut update = build_channels_update(result.new_channels);
3823    update.delete_channels = result
3824        .removed_channels
3825        .into_iter()
3826        .map(|id| id.to_proto())
3827        .collect();
3828    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3829
3830    for connection_id in connection_pool.user_connection_ids(user_id) {
3831        peer.send(connection_id, user_channels_update.clone())
3832            .trace_err();
3833        peer.send(connection_id, update.clone()).trace_err();
3834    }
3835}
3836
3837fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3838    proto::UpdateUserChannels {
3839        channel_memberships: channels
3840            .channel_memberships
3841            .iter()
3842            .map(|m| proto::ChannelMembership {
3843                channel_id: m.channel_id.to_proto(),
3844                role: m.role.into(),
3845            })
3846            .collect(),
3847        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3848    }
3849}
3850
3851fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3852    let mut update = proto::UpdateChannels::default();
3853
3854    for channel in channels.channels {
3855        update.channels.push(channel.to_proto());
3856    }
3857
3858    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3859
3860    for (channel_id, participants) in channels.channel_participants {
3861        update
3862            .channel_participants
3863            .push(proto::ChannelParticipants {
3864                channel_id: channel_id.to_proto(),
3865                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3866            });
3867    }
3868
3869    for channel in channels.invited_channels {
3870        update.channel_invitations.push(channel.to_proto());
3871    }
3872
3873    update
3874}
3875
3876fn build_initial_contacts_update(
3877    contacts: Vec<db::Contact>,
3878    pool: &ConnectionPool,
3879) -> proto::UpdateContacts {
3880    let mut update = proto::UpdateContacts::default();
3881
3882    for contact in contacts {
3883        match contact {
3884            db::Contact::Accepted { user_id, busy } => {
3885                update.contacts.push(contact_for_user(user_id, busy, pool));
3886            }
3887            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3888            db::Contact::Incoming { user_id } => {
3889                update
3890                    .incoming_requests
3891                    .push(proto::IncomingContactRequest {
3892                        requester_id: user_id.to_proto(),
3893                    })
3894            }
3895        }
3896    }
3897
3898    update
3899}
3900
3901fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3902    proto::Contact {
3903        user_id: user_id.to_proto(),
3904        online: pool.is_user_online(user_id),
3905        busy,
3906    }
3907}
3908
3909fn room_updated(room: &proto::Room, peer: &Peer) {
3910    broadcast(
3911        None,
3912        room.participants
3913            .iter()
3914            .filter_map(|participant| Some(participant.peer_id?.into())),
3915        |peer_id| {
3916            peer.send(
3917                peer_id,
3918                proto::RoomUpdated {
3919                    room: Some(room.clone()),
3920                },
3921            )
3922        },
3923    );
3924}
3925
3926fn channel_updated(
3927    channel: &db::channel::Model,
3928    room: &proto::Room,
3929    peer: &Peer,
3930    pool: &ConnectionPool,
3931) {
3932    let participants = room
3933        .participants
3934        .iter()
3935        .map(|p| p.user_id)
3936        .collect::<Vec<_>>();
3937
3938    broadcast(
3939        None,
3940        pool.channel_connection_ids(channel.root_id())
3941            .filter_map(|(channel_id, role)| {
3942                role.can_see_channel(channel.visibility)
3943                    .then_some(channel_id)
3944            }),
3945        |peer_id| {
3946            peer.send(
3947                peer_id,
3948                proto::UpdateChannels {
3949                    channel_participants: vec![proto::ChannelParticipants {
3950                        channel_id: channel.id.to_proto(),
3951                        participant_user_ids: participants.clone(),
3952                    }],
3953                    ..Default::default()
3954                },
3955            )
3956        },
3957    );
3958}
3959
3960async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3961    let db = session.db().await;
3962
3963    let contacts = db.get_contacts(user_id).await?;
3964    let busy = db.is_user_busy(user_id).await?;
3965
3966    let pool = session.connection_pool().await;
3967    let updated_contact = contact_for_user(user_id, busy, &pool);
3968    for contact in contacts {
3969        if let db::Contact::Accepted {
3970            user_id: contact_user_id,
3971            ..
3972        } = contact
3973        {
3974            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3975                session
3976                    .peer
3977                    .send(
3978                        contact_conn_id,
3979                        proto::UpdateContacts {
3980                            contacts: vec![updated_contact.clone()],
3981                            remove_contacts: Default::default(),
3982                            incoming_requests: Default::default(),
3983                            remove_incoming_requests: Default::default(),
3984                            outgoing_requests: Default::default(),
3985                            remove_outgoing_requests: Default::default(),
3986                        },
3987                    )
3988                    .trace_err();
3989            }
3990        }
3991    }
3992    Ok(())
3993}
3994
/// Remove `connection_id` from the room it is in, and notify everyone affected.
///
/// Returns `Ok(())` without doing anything if the database returns no room for the
/// connection. Otherwise, in order:
/// - each project in the left room's `left_projects` is handled by `project_left`,
/// - the room's participants receive the updated room,
/// - if the room belongs to a channel, that channel's eligible connections receive
///   its new participant list,
/// - users in `canceled_calls_to_user_ids` receive `CallCanceled`,
/// - the contact status of the session user and of those users is refreshed,
/// - when LiveKit is configured, the session user is removed from the LiveKit
///   room, and the room is deleted if the database reports it as deleted.
///
/// Peer-send and LiveKit failures are logged via `trace_err`, not returned.
///
/// NOTE(review): the contact and LiveKit updates use `session.user_id()`. This
/// assumes `connection_id` belongs to the session's user, which holds for the
/// stale-connection caller in `join_channel_internal`; confirm for other callers.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared here and assigned inside the `if let` below so that they outlive it.
    // The database handle used in the `if let` condition is a temporary, dropped
    // when that statement ends and before the connection pool is used below.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): the LiveKit room name is taken out of `left_room.room` before
        // the room itself is taken. The `room` broadcast by `room_updated` below
        // therefore carries a default `livekit_room`; confirm that clients do not
        // rely on that field here.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The database had no room to leave for this connection.
        return Ok(());
    }

    // Channel rooms also refresh the channel's participant list for its subscribers.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Users whose calls were canceled are told so, and their contact status is
    // refreshed along with the session user's.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // Best-effort LiveKit cleanup: failures are logged and ignored.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4068
4069async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4070    let left_channel_buffers = session
4071        .db()
4072        .await
4073        .leave_channel_buffers(session.connection_id)
4074        .await?;
4075
4076    for left_buffer in left_channel_buffers {
4077        channel_buffer_updated(
4078            session.connection_id,
4079            left_buffer.connections,
4080            &proto::UpdateChannelBufferCollaborators {
4081                channel_id: left_buffer.channel_id.to_proto(),
4082                collaborators: left_buffer.collaborators,
4083            },
4084            &session.peer,
4085        );
4086    }
4087
4088    Ok(())
4089}
4090
4091fn project_left(project: &db::LeftProject, session: &Session) {
4092    for connection_id in &project.connection_ids {
4093        if project.should_unshare {
4094            session
4095                .peer
4096                .send(
4097                    *connection_id,
4098                    proto::UnshareProject {
4099                        project_id: project.id.to_proto(),
4100                    },
4101                )
4102                .trace_err();
4103        } else {
4104            session
4105                .peer
4106                .send(
4107                    *connection_id,
4108                    proto::RemoveProjectCollaborator {
4109                        project_id: project.id.to_proto(),
4110                        peer_id: Some(session.connection_id.into()),
4111                    },
4112                )
4113                .trace_err();
4114        }
4115    }
4116}
4117
4118pub trait ResultExt {
4119    type Ok;
4120
4121    fn trace_err(self) -> Option<Self::Ok>;
4122}
4123
4124impl<T, E> ResultExt for Result<T, E>
4125where
4126    E: std::fmt::Debug,
4127{
4128    type Ok = T;
4129
4130    #[track_caller]
4131    fn trace_err(self) -> Option<T> {
4132        match self {
4133            Ok(value) => Some(value),
4134            Err(error) => {
4135                tracing::error!("{:?}", error);
4136                None
4137            }
4138        }
4139    }
4140}