rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::{
   5    AppState, Error, Result, auth,
   6    db::{
   7        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
   8        InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
   9        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
  10        UserId,
  11    },
  12    executor::Executor,
  13};
  14use anyhow::{Context as _, anyhow, bail};
  15use async_tungstenite::tungstenite::{
  16    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  17};
  18use axum::headers::UserAgent;
  19use axum::{
  20    Extension, Router, TypedHeader,
  21    body::Body,
  22    extract::{
  23        ConnectInfo, WebSocketUpgrade,
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use futures::TryFutureExt as _;
  36use rpc::proto::split_repository_update;
  37use tracing::Span;
  38use util::paths::PathStyle;
  39
  40use futures::{
  41    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  42    stream::FuturesUnordered,
  43};
  44use prometheus::{IntGauge, register_int_gauge};
  45use rpc::{
  46    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  47    proto::{
  48        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  49        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  50    },
  51};
  52use semver::Version;
  53use serde::{Serialize, Serializer};
  54use std::{
  55    any::TypeId,
  56    future::Future,
  57    marker::PhantomData,
  58    mem,
  59    net::SocketAddr,
  60    ops::{Deref, DerefMut},
  61    rc::Rc,
  62    sync::{
  63        Arc, OnceLock,
  64        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  65    },
  66    time::{Duration, Instant},
  67};
  68use tokio::sync::{Semaphore, watch};
  69use tower::ServiceBuilder;
  70use tracing::{
  71    Instrument,
  72    field::{self},
  73    info_span, instrument,
  74};
  75
  76pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  77
  78// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  79pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  80
  81const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  82const MAX_CONCURRENT_CONNECTIONS: usize = 512;
  83
  84static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
  85
  86const TOTAL_DURATION_MS: &str = "total_duration_ms";
  87const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
  88const QUEUE_DURATION_MS: &str = "queue_duration_ms";
  89const HOST_WAITING_MS: &str = "host_waiting_ms";
  90
  91type MessageHandler =
  92    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
  93
  94pub struct ConnectionGuard;
  95
  96impl ConnectionGuard {
  97    pub fn try_acquire() -> Result<Self, ()> {
  98        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
  99        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 100            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 101            tracing::error!(
 102                "too many concurrent connections: {}",
 103                current_connections + 1
 104            );
 105            return Err(());
 106        }
 107        Ok(ConnectionGuard)
 108    }
 109}
 110
 111impl Drop for ConnectionGuard {
 112    fn drop(&mut self) {
 113        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 114    }
 115}
 116
 117struct Response<R> {
 118    peer: Arc<Peer>,
 119    receipt: Receipt<R>,
 120    responded: Arc<AtomicBool>,
 121}
 122
 123impl<R: RequestMessage> Response<R> {
 124    fn send(self, payload: R::Response) -> Result<()> {
 125        self.responded.store(true, SeqCst);
 126        self.peer.respond(self.receipt, payload)?;
 127        Ok(())
 128    }
 129}
 130
 131#[derive(Clone, Debug)]
 132pub enum Principal {
 133    User(User),
 134    Impersonated { user: User, admin: User },
 135}
 136
 137impl Principal {
 138    fn update_span(&self, span: &tracing::Span) {
 139        match &self {
 140            Principal::User(user) => {
 141                span.record("user_id", user.id.0);
 142                span.record("login", &user.github_login);
 143            }
 144            Principal::Impersonated { user, admin } => {
 145                span.record("user_id", user.id.0);
 146                span.record("login", &user.github_login);
 147                span.record("impersonator", &admin.github_login);
 148            }
 149        }
 150    }
 151}
 152
 153#[derive(Clone)]
 154struct MessageContext {
 155    session: Session,
 156    span: tracing::Span,
 157}
 158
 159impl Deref for MessageContext {
 160    type Target = Session;
 161
 162    fn deref(&self) -> &Self::Target {
 163        &self.session
 164    }
 165}
 166
 167impl MessageContext {
 168    pub fn forward_request<T: RequestMessage>(
 169        &self,
 170        receiver_id: ConnectionId,
 171        request: T,
 172    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 173        let request_start_time = Instant::now();
 174        let span = self.span.clone();
 175        tracing::info!("start forwarding request");
 176        self.peer
 177            .forward_request(self.connection_id, receiver_id, request)
 178            .inspect(move |_| {
 179                span.record(
 180                    HOST_WAITING_MS,
 181                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 182                );
 183            })
 184            .inspect_err(|_| tracing::error!("error forwarding request"))
 185            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 186    }
 187}
 188
 189#[derive(Clone)]
 190struct Session {
 191    principal: Principal,
 192    connection_id: ConnectionId,
 193    db: Arc<tokio::sync::Mutex<DbHandle>>,
 194    peer: Arc<Peer>,
 195    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 196    app_state: Arc<AppState>,
 197    /// The GeoIP country code for the user.
 198    #[allow(unused)]
 199    geoip_country_code: Option<String>,
 200    #[allow(unused)]
 201    system_id: Option<String>,
 202    _executor: Executor,
 203}
 204
 205impl Session {
 206    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 207        #[cfg(feature = "test-support")]
 208        tokio::task::yield_now().await;
 209        let guard = self.db.lock().await;
 210        #[cfg(feature = "test-support")]
 211        tokio::task::yield_now().await;
 212        guard
 213    }
 214
 215    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 216        #[cfg(feature = "test-support")]
 217        tokio::task::yield_now().await;
 218        let guard = self.connection_pool.lock();
 219        ConnectionPoolGuard {
 220            guard,
 221            _not_send: PhantomData,
 222        }
 223    }
 224
 225    #[expect(dead_code)]
 226    fn is_staff(&self) -> bool {
 227        match &self.principal {
 228            Principal::User(user) => user.admin,
 229            Principal::Impersonated { .. } => true,
 230        }
 231    }
 232
 233    fn user_id(&self) -> UserId {
 234        match &self.principal {
 235            Principal::User(user) => user.id,
 236            Principal::Impersonated { user, .. } => user.id,
 237        }
 238    }
 239}
 240
 241impl Debug for Session {
 242    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 243        let mut result = f.debug_struct("Session");
 244        match &self.principal {
 245            Principal::User(user) => {
 246                result.field("user", &user.github_login);
 247            }
 248            Principal::Impersonated { user, admin } => {
 249                result.field("user", &user.github_login);
 250                result.field("impersonator", &admin.github_login);
 251            }
 252        }
 253        result.field("connection_id", &self.connection_id).finish()
 254    }
 255}
 256
 257struct DbHandle(Arc<Database>);
 258
 259impl Deref for DbHandle {
 260    type Target = Database;
 261
 262    fn deref(&self) -> &Self::Target {
 263        self.0.as_ref()
 264    }
 265}
 266
 267pub struct Server {
 268    id: parking_lot::Mutex<ServerId>,
 269    peer: Arc<Peer>,
 270    pub connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 271    app_state: Arc<AppState>,
 272    handlers: HashMap<TypeId, MessageHandler>,
 273    teardown: watch::Sender<bool>,
 274}
 275
 276struct ConnectionPoolGuard<'a> {
 277    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 278    _not_send: PhantomData<Rc<()>>,
 279}
 280
 281#[derive(Serialize)]
 282pub struct ServerSnapshot<'a> {
 283    peer: &'a Peer,
 284    #[serde(serialize_with = "serialize_deref")]
 285    connection_pool: ConnectionPoolGuard<'a>,
 286}
 287
 288pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 289where
 290    S: Serializer,
 291    T: Deref<Target = U>,
 292    U: Serialize,
 293{
 294    Serialize::serialize(value.deref(), serializer)
 295}
 296
 297impl Server {
 298    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 299        let mut server = Self {
 300            id: parking_lot::Mutex::new(id),
 301            peer: Peer::new(id.0 as u32),
 302            app_state,
 303            connection_pool: Default::default(),
 304            handlers: Default::default(),
 305            teardown: watch::channel(false).0,
 306        };
 307
 308        server
 309            .add_request_handler(ping)
 310            .add_request_handler(create_room)
 311            .add_request_handler(join_room)
 312            .add_request_handler(rejoin_room)
 313            .add_request_handler(leave_room)
 314            .add_request_handler(set_room_participant_role)
 315            .add_request_handler(call)
 316            .add_request_handler(cancel_call)
 317            .add_message_handler(decline_call)
 318            .add_request_handler(update_participant_location)
 319            .add_request_handler(share_project)
 320            .add_message_handler(unshare_project)
 321            .add_request_handler(join_project)
 322            .add_message_handler(leave_project)
 323            .add_request_handler(update_project)
 324            .add_request_handler(update_worktree)
 325            .add_request_handler(update_repository)
 326            .add_request_handler(remove_repository)
 327            .add_message_handler(start_language_server)
 328            .add_message_handler(update_language_server)
 329            .add_message_handler(update_diagnostic_summary)
 330            .add_message_handler(update_worktree_settings)
 331            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 332            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 333            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 334            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 335            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 336            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 337            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 338            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 339            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 340            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 341            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
 342            .add_request_handler(forward_read_only_project_request::<proto::DownloadFileByPath>)
 343            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 344            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
 345            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 346            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 347            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 348            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 349            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 350            .add_request_handler(
 351                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 352            )
 353            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 354            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 355            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 356            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 357            .add_request_handler(
 358                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 359            )
 360            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 361            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 362            .add_request_handler(
 363                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 364            )
 365            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 366            .add_request_handler(
 367                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 368            )
 369            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 370            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 371            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 372            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 373            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 374            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 375            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 376            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 377            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 378            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 379            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 380            .add_request_handler(
 381                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 382            )
 383            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 384            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 385            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 386            .add_request_handler(lsp_query)
 387            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
 388            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 389            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 390            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 391            .add_message_handler(create_buffer_for_peer)
 392            .add_message_handler(create_image_for_peer)
 393            .add_request_handler(update_buffer)
 394            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 395            .add_message_handler(
 396                broadcast_project_message_from_host::<proto::RefreshSemanticTokens>,
 397            )
 398            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 399            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 400            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 401            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 402            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 403            .add_message_handler(
 404                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 405            )
 406            .add_request_handler(get_users)
 407            .add_request_handler(fuzzy_search_users)
 408            .add_request_handler(request_contact)
 409            .add_request_handler(remove_contact)
 410            .add_request_handler(respond_to_contact_request)
 411            .add_message_handler(subscribe_to_channels)
 412            .add_request_handler(create_channel)
 413            .add_request_handler(delete_channel)
 414            .add_request_handler(invite_channel_member)
 415            .add_request_handler(remove_channel_member)
 416            .add_request_handler(set_channel_member_role)
 417            .add_request_handler(set_channel_visibility)
 418            .add_request_handler(rename_channel)
 419            .add_request_handler(join_channel_buffer)
 420            .add_request_handler(leave_channel_buffer)
 421            .add_message_handler(update_channel_buffer)
 422            .add_request_handler(rejoin_channel_buffers)
 423            .add_request_handler(get_channel_members)
 424            .add_request_handler(respond_to_channel_invite)
 425            .add_request_handler(join_channel)
 426            .add_request_handler(join_channel_chat)
 427            .add_message_handler(leave_channel_chat)
 428            .add_request_handler(send_channel_message)
 429            .add_request_handler(remove_channel_message)
 430            .add_request_handler(update_channel_message)
 431            .add_request_handler(get_channel_messages)
 432            .add_request_handler(get_channel_messages_by_id)
 433            .add_request_handler(get_notifications)
 434            .add_request_handler(mark_notification_as_read)
 435            .add_request_handler(move_channel)
 436            .add_request_handler(reorder_channel)
 437            .add_request_handler(follow)
 438            .add_message_handler(unfollow)
 439            .add_message_handler(update_followers)
 440            .add_message_handler(acknowledge_channel_message)
 441            .add_message_handler(acknowledge_buffer_version)
 442            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 443            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 444            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 445            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 446            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 447            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 448            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 449            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
 450            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 451            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
 452            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 453            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 454            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 455            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 456            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 457            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 458            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 459            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 460            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 461            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 462            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 463            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
 464            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
 465            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 466            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 467            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
 468            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
 469            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 470            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 471            .add_message_handler(update_context)
 472            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
 473            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
 474            .add_request_handler(share_agent_thread)
 475            .add_request_handler(get_shared_agent_thread)
 476            .add_request_handler(forward_project_search_chunk);
 477
 478        Arc::new(server)
 479    }
 480
    /// Schedules cleanup of state left behind by earlier server instances.
    ///
    /// Spawns a detached task, instrumented with a `start server` span. The task
    /// first awaits a sleep of [`CLEANUP_TIMEOUT`] and then, for this server's id:
    /// - calls `delete_stale_channel_chat_participants`;
    /// - fetches stale room and channel ids via `stale_server_resource_ids`;
    /// - for each stale channel, clears stale buffer collaborators and sends
    ///   `UpdateChannelBufferCollaborators` to the connections still attached;
    /// - for each stale room, clears stale participants, broadcasts the room (and
    ///   channel) update, sends `CallCanceled` for canceled calls, pushes refreshed
    ///   contact entries to affected users' accepted contacts, and deletes the
    ///   LiveKit room when no participants remain;
    /// - finally clears old worktree entries and deletes stale servers.
    ///
    /// Failures inside the task go through `trace_err` and do not abort the remaining
    /// cleanup steps. This method returns `Ok(())` as soon as the task is spawned.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and push the
                    // refreshed collaborator list to every connection still attached.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled from the refreshed room (if the DB call succeeds) and
                        // consumed by the notification / LiveKit steps below.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            // Move out what is still needed after the room updates above
                            // instead of cloning it.
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the next `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // For each affected user, send their refreshed contact entry to
                        // every connection of each accepted contact. The user is skipped
                        // if either DB lookup fails.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room only when a client is configured and
                        // the refreshed room has no participants left.
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                // NOTE(review): this repeats the `delete_stale_channel_chat_participants`
                // call made right after the timeout. Nothing in between appears to
                // create chat participants, so it looks redundant. Confirm before removing.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 651
 652    pub fn teardown(&self) {
 653        self.peer.teardown();
 654        self.connection_pool.lock().reset();
 655        let _ = self.teardown.send(true);
 656    }
 657
 658    #[cfg(feature = "test-support")]
 659    pub fn reset(&self, id: ServerId) {
 660        self.teardown();
 661        *self.id.lock() = id;
 662        self.peer.reset(id.0 as u32);
 663        let _ = self.teardown.send(false);
 664    }
 665
 666    #[cfg(feature = "test-support")]
 667    pub fn id(&self) -> ServerId {
 668        *self.id.lock()
 669    }
 670
 671    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 672    where
 673        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
 674        Fut: 'static + Send + Future<Output = Result<()>>,
 675        M: EnvelopedMessage,
 676    {
 677        let prev_handler = self.handlers.insert(
 678            TypeId::of::<M>(),
 679            Box::new(move |envelope, session, span| {
 680                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 681                let received_at = envelope.received_at;
 682                tracing::info!("message received");
 683                let start_time = Instant::now();
 684                let future = (handler)(
 685                    *envelope,
 686                    MessageContext {
 687                        session,
 688                        span: span.clone(),
 689                    },
 690                );
 691                async move {
 692                    let result = future.await;
 693                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 694                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 695                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 696                    span.record(TOTAL_DURATION_MS, total_duration_ms);
 697                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
 698                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
 699                    match result {
 700                        Err(error) => {
 701                            tracing::error!(?error, "error handling message")
 702                        }
 703                        Ok(()) => tracing::info!("finished handling message"),
 704                    }
 705                }
 706                .boxed()
 707            }),
 708        );
 709        if prev_handler.is_some() {
 710            panic!("registered a handler for the same message twice");
 711        }
 712        self
 713    }
 714
 715    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 716    where
 717        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 718        Fut: 'static + Send + Future<Output = Result<()>>,
 719        M: EnvelopedMessage,
 720    {
 721        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 722        self
 723    }
 724
 725    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 726    where
 727        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 728        Fut: Send + Future<Output = Result<()>>,
 729        M: RequestMessage,
 730    {
 731        let handler = Arc::new(handler);
 732        self.add_handler(move |envelope, session| {
 733            let receipt = envelope.receipt();
 734            let handler = handler.clone();
 735            async move {
 736                let peer = session.peer.clone();
 737                let responded = Arc::new(AtomicBool::default());
 738                let response = Response {
 739                    peer: peer.clone(),
 740                    responded: responded.clone(),
 741                    receipt,
 742                };
 743                match (handler)(envelope.payload, response, session).await {
 744                    Ok(()) => {
 745                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 746                            Ok(())
 747                        } else {
 748                            Err(anyhow!("handler did not send a response"))?
 749                        }
 750                    }
 751                    Err(error) => {
 752                        let proto_err = match &error {
 753                            Error::Internal(err) => err.to_proto(),
 754                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 755                        };
 756                        peer.respond_with_error(receipt, proto_err)?;
 757                        Err(error)
 758                    }
 759                }
 760            }
 761        })
 762    }
 763
    /// Builds the future that serves one client connection from start to finish.
    ///
    /// The returned future captures no borrows of `self` (note the `use<>`
    /// bound). When polled, it:
    /// 1. returns immediately if the server's teardown flag is already set;
    /// 2. registers `connection` with the peer and records the assigned
    ///    connection id on the connection span;
    /// 3. builds the connection's `Session` and sends the initial client update
    ///    (via `send_initial_client_update`);
    /// 4. drops `connection_guard`;
    /// 5. dispatches incoming messages to the registered handlers until the I/O
    ///    future finishes, the incoming stream ends, or teardown is signalled;
    /// 6. unless it stopped because of teardown, runs `connection_lost` to clean
    ///    up after the session.
    ///
    /// If `send_connection_id` is provided, it receives the connection id right
    /// after the hello message has been sent.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        // Owned handle to the server for the returned future.
        let this = self.clone();
        // Fields declared `Empty` are filled in later via `record` (or, presumably,
        // by `principal.update_span`).
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Watches the server's teardown flag for the lifetime of this connection.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            // The closure lets the peer sleep using this connection's executor.
            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
            };

            // NOTE(review): this early return does not call `connection_lost`.
            // Anything already registered for `connection_id` (by
            // `peer.add_connection`, or the pool entry added inside
            // `send_initial_client_update`) may not be cleaned up. Confirm that
            // this is handled elsewhere.
            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The initial update is done; release the connection guard (if any).
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            //
            // A permit is acquired *before* each message is read from `incoming_rx`, and it
            // is released only when that message's handler finishes. So at most
            // `MAX_CONCURRENT_HANDLERS` handlers (foreground and background) are in flight,
            // and no further messages are read while the limit is reached.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // Biased select: teardown wins over everything, then I/O completion,
                // then progress on in-flight foreground handlers, then the next message.
                futures::select_biased! {
                    // Teardown: stop immediately, without running `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            // Entered while the handler future is created, so work done at
                            // creation time (e.g. the "message received" log) is attributed
                            // to this span.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // The permit moves into the handler future and is released
                                // when it completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached; foreground ones are
                                // polled by this loop via `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that are still pending.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
 936
 937    async fn send_initial_client_update(
 938        &self,
 939        connection_id: ConnectionId,
 940        zed_version: ZedVersion,
 941        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 942        session: &Session,
 943    ) -> Result<()> {
 944        self.peer.send(
 945            connection_id,
 946            proto::Hello {
 947                peer_id: Some(connection_id.into()),
 948            },
 949        )?;
 950        tracing::info!("sent hello message");
 951        if let Some(send_connection_id) = send_connection_id.take() {
 952            let _ = send_connection_id.send(connection_id);
 953        }
 954
 955        match &session.principal {
 956            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 957                if !user.connected_once {
 958                    self.peer.send(connection_id, proto::ShowContacts {})?;
 959                    self.app_state
 960                        .db
 961                        .set_user_connected_once(user.id, true)
 962                        .await?;
 963                }
 964
 965                let contacts = self.app_state.db.get_contacts(user.id).await?;
 966
 967                {
 968                    let mut pool = self.connection_pool.lock();
 969                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
 970                    self.peer.send(
 971                        connection_id,
 972                        build_initial_contacts_update(contacts, &pool),
 973                    )?;
 974                }
 975
 976                if should_auto_subscribe_to_channels(&zed_version) {
 977                    subscribe_user_to_channels(user.id, session).await?;
 978                }
 979
 980                if let Some(incoming_call) =
 981                    self.app_state.db.incoming_call_for_user(user.id).await?
 982                {
 983                    self.peer.send(connection_id, incoming_call)?;
 984                }
 985
 986                update_user_contacts(user.id, session).await?;
 987            }
 988        }
 989
 990        Ok(())
 991    }
 992
 993    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
 994        ServerSnapshot {
 995            connection_pool: ConnectionPoolGuard {
 996                guard: self.connection_pool.lock(),
 997                _not_send: PhantomData,
 998            },
 999            peer: &self.peer,
1000        }
1001    }
1002}
1003
1004impl Deref for ConnectionPoolGuard<'_> {
1005    type Target = ConnectionPool;
1006
1007    fn deref(&self) -> &Self::Target {
1008        &self.guard
1009    }
1010}
1011
1012impl DerefMut for ConnectionPoolGuard<'_> {
1013    fn deref_mut(&mut self) -> &mut Self::Target {
1014        &mut self.guard
1015    }
1016}
1017
1018impl Drop for ConnectionPoolGuard<'_> {
1019    fn drop(&mut self) {
1020        #[cfg(feature = "test-support")]
1021        self.check_invariants();
1022    }
1023}
1024
1025fn broadcast<F>(
1026    sender_id: Option<ConnectionId>,
1027    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1028    mut f: F,
1029) where
1030    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1031{
1032    for receiver_id in receiver_ids {
1033        if Some(receiver_id) != sender_id
1034            && let Err(error) = f(receiver_id)
1035        {
1036            tracing::error!("failed to send to {:?} {}", receiver_id, error);
1037        }
1038    }
1039}
1040
1041pub struct ProtocolVersion(u32);
1042
1043impl Header for ProtocolVersion {
1044    fn name() -> &'static HeaderName {
1045        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1046        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1047    }
1048
1049    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1050    where
1051        Self: Sized,
1052        I: Iterator<Item = &'i axum::http::HeaderValue>,
1053    {
1054        let version = values
1055            .next()
1056            .ok_or_else(axum::headers::Error::invalid)?
1057            .to_str()
1058            .map_err(|_| axum::headers::Error::invalid())?
1059            .parse()
1060            .map_err(|_| axum::headers::Error::invalid())?;
1061        Ok(Self(version))
1062    }
1063
1064    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1065        values.extend([self.0.to_string().parse().unwrap()]);
1066    }
1067}
1068
1069pub struct AppVersionHeader(Version);
1070impl Header for AppVersionHeader {
1071    fn name() -> &'static HeaderName {
1072        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1073        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1074    }
1075
1076    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1077    where
1078        Self: Sized,
1079        I: Iterator<Item = &'i axum::http::HeaderValue>,
1080    {
1081        let version = values
1082            .next()
1083            .ok_or_else(axum::headers::Error::invalid)?
1084            .to_str()
1085            .map_err(|_| axum::headers::Error::invalid())?
1086            .parse()
1087            .map_err(|_| axum::headers::Error::invalid())?;
1088        Ok(Self(version))
1089    }
1090
1091    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1092        values.extend([self.0.to_string().parse().unwrap()]);
1093    }
1094}
1095
1096#[derive(Debug)]
1097pub struct ReleaseChannelHeader(String);
1098
1099impl Header for ReleaseChannelHeader {
1100    fn name() -> &'static HeaderName {
1101        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1102        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1103    }
1104
1105    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1106    where
1107        Self: Sized,
1108        I: Iterator<Item = &'i axum::http::HeaderValue>,
1109    {
1110        Ok(Self(
1111            values
1112                .next()
1113                .ok_or_else(axum::headers::Error::invalid)?
1114                .to_str()
1115                .map_err(|_| axum::headers::Error::invalid())?
1116                .to_owned(),
1117        ))
1118    }
1119
1120    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1121        values.extend([self.0.parse().unwrap()]);
1122    }
1123}
1124
1125pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1126    Router::new()
1127        .route("/rpc", get(handle_websocket_request))
1128        .layer(
1129            ServiceBuilder::new()
1130                .layer(Extension(server.app_state.clone()))
1131                .layer(middleware::from_fn(auth::validate_header)),
1132        )
1133        .route("/metrics", get(handle_metrics))
1134        .layer(Extension(server))
1135}
1136
1137pub async fn handle_websocket_request(
1138    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1139    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1140    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1141    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1142    Extension(server): Extension<Arc<Server>>,
1143    Extension(principal): Extension<Principal>,
1144    user_agent: Option<TypedHeader<UserAgent>>,
1145    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1146    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1147    ws: WebSocketUpgrade,
1148) -> axum::response::Response {
1149    if protocol_version != rpc::PROTOCOL_VERSION {
1150        return (
1151            StatusCode::UPGRADE_REQUIRED,
1152            "client must be upgraded".to_string(),
1153        )
1154            .into_response();
1155    }
1156
1157    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1158        return (
1159            StatusCode::UPGRADE_REQUIRED,
1160            "no version header found".to_string(),
1161        )
1162            .into_response();
1163    };
1164
1165    let release_channel = release_channel_header.map(|header| header.0.0);
1166
1167    if !version.can_collaborate() {
1168        return (
1169            StatusCode::UPGRADE_REQUIRED,
1170            "client must be upgraded".to_string(),
1171        )
1172            .into_response();
1173    }
1174
1175    let socket_address = socket_address.to_string();
1176
1177    // Acquire connection guard before WebSocket upgrade
1178    let connection_guard = match ConnectionGuard::try_acquire() {
1179        Ok(guard) => guard,
1180        Err(()) => {
1181            return (
1182                StatusCode::SERVICE_UNAVAILABLE,
1183                "Too many concurrent connections",
1184            )
1185                .into_response();
1186        }
1187    };
1188
1189    ws.on_upgrade(move |socket| {
1190        let socket = socket
1191            .map_ok(to_tungstenite_message)
1192            .err_into()
1193            .with(|message| async move { to_axum_message(message) });
1194        let connection = Connection::new(Box::pin(socket));
1195        async move {
1196            server
1197                .handle_connection(
1198                    connection,
1199                    socket_address,
1200                    principal,
1201                    version,
1202                    release_channel,
1203                    user_agent.map(|header| header.to_string()),
1204                    country_code_header.map(|header| header.to_string()),
1205                    system_id_header.map(|header| header.to_string()),
1206                    None,
1207                    Executor::Production,
1208                    Some(connection_guard),
1209                )
1210                .await;
1211        }
1212    })
1213}
1214
1215pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1216    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1217    let connections_metric = CONNECTIONS_METRIC
1218        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1219
1220    let connections = server
1221        .connection_pool
1222        .lock()
1223        .connections()
1224        .filter(|connection| !connection.admin)
1225        .count();
1226    connections_metric.set(connections as _);
1227
1228    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1229    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1230        register_int_gauge!(
1231            "shared_projects",
1232            "number of open projects with one or more guests"
1233        )
1234        .unwrap()
1235    });
1236
1237    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1238    shared_projects_metric.set(shared_projects as _);
1239
1240    let encoder = prometheus::TextEncoder::new();
1241    let metric_families = prometheus::gather();
1242    let encoded_metrics = encoder
1243        .encode_to_string(&metric_families)
1244        .map_err(|err| anyhow!("{err}"))?;
1245    Ok(encoded_metrics)
1246}
1247
/// Cleans up after a connection's message loop has ended.
///
/// Immediately:
/// - disconnects the connection from the peer;
/// - removes it from the connection pool, propagating any error;
/// - reports the lost connection to the database (errors are only logged).
///
/// It then waits for `RECONNECT_TIMEOUT` — presumably a grace period during
/// which the client may reconnect (TODO confirm). If the server's teardown flag
/// changes first, nothing further happens. When the timeout elapses:
/// - the session leaves its room and channel buffers;
/// - if the user no longer has any connection online, `decline_call(None, ..)`
///   is called and the resulting room update (if any) is broadcast;
/// - the user's contacts are refreshed.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased: when both are ready, the timeout branch is taken.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            // NOTE(review): this uses `log::info!` while the rest of this file uses
            // `tracing`; consider `tracing::info!` for consistency.
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only if no other connection of this user is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        // Server teardown: skip the deferred cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1294
1295/// Acknowledges a ping from a client, used to keep the connection alive.
1296async fn ping(
1297    _: proto::Ping,
1298    response: Response<proto::Ping>,
1299    _session: MessageContext,
1300) -> Result<()> {
1301    response.send(proto::Ack {})?;
1302    Ok(())
1303}
1304
1305/// Creates a new room for calling (outside of channels)
1306async fn create_room(
1307    _request: proto::CreateRoom,
1308    response: Response<proto::CreateRoom>,
1309    session: MessageContext,
1310) -> Result<()> {
1311    let livekit_room = nanoid::nanoid!(30);
1312
1313    let live_kit_connection_info = util::maybe!(async {
1314        let live_kit = session.app_state.livekit_client.as_ref();
1315        let live_kit = live_kit?;
1316        let user_id = session.user_id().to_string();
1317
1318        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1319
1320        Some(proto::LiveKitConnectionInfo {
1321            server_url: live_kit.url().into(),
1322            token,
1323            can_publish: true,
1324        })
1325    })
1326    .await;
1327
1328    let room = session
1329        .db()
1330        .await
1331        .create_room(session.user_id(), session.connection_id, &livekit_room)
1332        .await?;
1333
1334    response.send(proto::CreateRoomResponse {
1335        room: Some(room.clone()),
1336        live_kit_connection_info,
1337    })?;
1338
1339    update_user_contacts(session.user_id(), &session).await?;
1340    Ok(())
1341}
1342
1343/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1344async fn join_room(
1345    request: proto::JoinRoom,
1346    response: Response<proto::JoinRoom>,
1347    session: MessageContext,
1348) -> Result<()> {
1349    let room_id = RoomId::from_proto(request.id);
1350
1351    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1352
1353    if let Some(channel_id) = channel_id {
1354        return join_channel_internal(channel_id, Box::new(response), session).await;
1355    }
1356
1357    let joined_room = {
1358        let room = session
1359            .db()
1360            .await
1361            .join_room(room_id, session.user_id(), session.connection_id)
1362            .await?;
1363        room_updated(&room.room, &session.peer);
1364        room.into_inner()
1365    };
1366
1367    for connection_id in session
1368        .connection_pool()
1369        .await
1370        .user_connection_ids(session.user_id())
1371    {
1372        session
1373            .peer
1374            .send(
1375                connection_id,
1376                proto::CallCanceled {
1377                    room_id: room_id.to_proto(),
1378                },
1379            )
1380            .trace_err();
1381    }
1382
1383    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1384    {
1385        live_kit
1386            .room_token(
1387                &joined_room.room.livekit_room,
1388                &session.user_id().to_string(),
1389            )
1390            .trace_err()
1391            .map(|token| proto::LiveKitConnectionInfo {
1392                server_url: live_kit.url().into(),
1393                token,
1394                can_publish: true,
1395            })
1396    } else {
1397        None
1398    };
1399
1400    response.send(proto::JoinRoomResponse {
1401        room: Some(joined_room.room),
1402        channel_id: None,
1403        live_kit_connection_info,
1404    })?;
1405
1406    update_user_contacts(session.user_id(), &session).await?;
1407    Ok(())
1408}
1409
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Steps:
/// 1. Restore the caller's room membership through the database.
/// 2. Reply with the room and the reshared/rejoined projects, then broadcast
///    the room update.
/// 3. For each reshared project:
///    - tell every collaborator that the caller's peer id changed from the
///      project's old connection to the current one;
///    - forward the project's worktrees to every collaborator except the caller.
/// 4. Handle rejoined projects via `notify_rejoined_projects`.
/// 5. If the room belongs to a channel, broadcast a channel update.
/// 6. Refresh the caller's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    // Declared here and assigned at the end of the inner block so the values outlive
    // `rejoined_room`. Its guard is presumably released by `into_inner()` before the
    // channel update below runs (TODO confirm).
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Every collaborator learns the caller's new peer id; failures are only logged.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktrees to every collaborator except the caller.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1501
1502fn notify_rejoined_projects(
1503    rejoined_projects: &mut Vec<RejoinedProject>,
1504    session: &Session,
1505) -> Result<()> {
1506    for project in rejoined_projects.iter() {
1507        for collaborator in &project.collaborators {
1508            session
1509                .peer
1510                .send(
1511                    collaborator.connection_id,
1512                    proto::UpdateProjectCollaborator {
1513                        project_id: project.id.to_proto(),
1514                        old_peer_id: Some(project.old_connection_id.into()),
1515                        new_peer_id: Some(session.connection_id.into()),
1516                    },
1517                )
1518                .trace_err();
1519        }
1520    }
1521
1522    for project in rejoined_projects {
1523        for worktree in mem::take(&mut project.worktrees) {
1524            // Stream this worktree's entries.
1525            let message = proto::UpdateWorktree {
1526                project_id: project.id.to_proto(),
1527                worktree_id: worktree.id,
1528                abs_path: worktree.abs_path.clone(),
1529                root_name: worktree.root_name,
1530                updated_entries: worktree.updated_entries,
1531                removed_entries: worktree.removed_entries,
1532                scan_id: worktree.scan_id,
1533                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1534                updated_repositories: worktree.updated_repositories,
1535                removed_repositories: worktree.removed_repositories,
1536            };
1537            for update in proto::split_worktree_update(message) {
1538                session.peer.send(session.connection_id, update)?;
1539            }
1540
1541            // Stream this worktree's diagnostics.
1542            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1543            if let Some(summary) = worktree_diagnostics.next() {
1544                let message = proto::UpdateDiagnosticSummary {
1545                    project_id: project.id.to_proto(),
1546                    worktree_id: worktree.id,
1547                    summary: Some(summary),
1548                    more_summaries: worktree_diagnostics.collect(),
1549                };
1550                session.peer.send(session.connection_id, message)?;
1551            }
1552
1553            for settings_file in worktree.settings_files {
1554                session.peer.send(
1555                    session.connection_id,
1556                    proto::UpdateWorktreeSettings {
1557                        project_id: project.id.to_proto(),
1558                        worktree_id: worktree.id,
1559                        path: settings_file.path,
1560                        content: Some(settings_file.content),
1561                        kind: Some(settings_file.kind.to_proto().into()),
1562                        outside_worktree: Some(settings_file.outside_worktree),
1563                    },
1564                )?;
1565            }
1566        }
1567
1568        for repository in mem::take(&mut project.updated_repositories) {
1569            for update in split_repository_update(repository) {
1570                session.peer.send(session.connection_id, update)?;
1571            }
1572        }
1573
1574        for id in mem::take(&mut project.removed_repositories) {
1575            session.peer.send(
1576                session.connection_id,
1577                proto::RemoveRepository {
1578                    project_id: project.id.to_proto(),
1579                    id,
1580                },
1581            )?;
1582        }
1583    }
1584
1585    Ok(())
1586}
1587
1588/// leave room disconnects from the room.
1589async fn leave_room(
1590    _: proto::LeaveRoom,
1591    response: Response<proto::LeaveRoom>,
1592    session: MessageContext,
1593) -> Result<()> {
1594    leave_room_for_session(&session, session.connection_id).await?;
1595    response.send(proto::Ack {})?;
1596    Ok(())
1597}
1598
1599/// Updates the permissions of someone else in the room.
1600async fn set_room_participant_role(
1601    request: proto::SetRoomParticipantRole,
1602    response: Response<proto::SetRoomParticipantRole>,
1603    session: MessageContext,
1604) -> Result<()> {
1605    let user_id = UserId::from_proto(request.user_id);
1606    let role = ChannelRole::from(request.role());
1607
1608    let (livekit_room, can_publish) = {
1609        let room = session
1610            .db()
1611            .await
1612            .set_room_participant_role(
1613                session.user_id(),
1614                RoomId::from_proto(request.room_id),
1615                user_id,
1616                role,
1617            )
1618            .await?;
1619
1620        let livekit_room = room.livekit_room.clone();
1621        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1622        room_updated(&room, &session.peer);
1623        (livekit_room, can_publish)
1624    };
1625
1626    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1627        live_kit
1628            .update_participant(
1629                livekit_room.clone(),
1630                request.user_id.to_string(),
1631                livekit_api::proto::ParticipantPermission {
1632                    can_subscribe: true,
1633                    can_publish,
1634                    can_publish_data: can_publish,
1635                    hidden: false,
1636                    recorder: false,
1637                },
1638            )
1639            .await
1640            .trace_err();
1641    }
1642
1643    response.send(proto::Ack {})?;
1644    Ok(())
1645}
1646
1647/// Call someone else into the current room
1648async fn call(
1649    request: proto::Call,
1650    response: Response<proto::Call>,
1651    session: MessageContext,
1652) -> Result<()> {
1653    let room_id = RoomId::from_proto(request.room_id);
1654    let calling_user_id = session.user_id();
1655    let calling_connection_id = session.connection_id;
1656    let called_user_id = UserId::from_proto(request.called_user_id);
1657    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1658    if !session
1659        .db()
1660        .await
1661        .has_contact(calling_user_id, called_user_id)
1662        .await?
1663    {
1664        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1665    }
1666
1667    let incoming_call = {
1668        let (room, incoming_call) = &mut *session
1669            .db()
1670            .await
1671            .call(
1672                room_id,
1673                calling_user_id,
1674                calling_connection_id,
1675                called_user_id,
1676                initial_project_id,
1677            )
1678            .await?;
1679        room_updated(room, &session.peer);
1680        mem::take(incoming_call)
1681    };
1682    update_user_contacts(called_user_id, &session).await?;
1683
1684    let mut calls = session
1685        .connection_pool()
1686        .await
1687        .user_connection_ids(called_user_id)
1688        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1689        .collect::<FuturesUnordered<_>>();
1690
1691    while let Some(call_response) = calls.next().await {
1692        match call_response.as_ref() {
1693            Ok(_) => {
1694                response.send(proto::Ack {})?;
1695                return Ok(());
1696            }
1697            Err(_) => {
1698                call_response.trace_err();
1699            }
1700        }
1701    }
1702
1703    {
1704        let room = session
1705            .db()
1706            .await
1707            .call_failed(room_id, called_user_id)
1708            .await?;
1709        room_updated(&room, &session.peer);
1710    }
1711    update_user_contacts(called_user_id, &session).await?;
1712
1713    Err(anyhow!("failed to ring user"))?
1714}
1715
1716/// Cancel an outgoing call.
1717async fn cancel_call(
1718    request: proto::CancelCall,
1719    response: Response<proto::CancelCall>,
1720    session: MessageContext,
1721) -> Result<()> {
1722    let called_user_id = UserId::from_proto(request.called_user_id);
1723    let room_id = RoomId::from_proto(request.room_id);
1724    {
1725        let room = session
1726            .db()
1727            .await
1728            .cancel_call(room_id, session.connection_id, called_user_id)
1729            .await?;
1730        room_updated(&room, &session.peer);
1731    }
1732
1733    for connection_id in session
1734        .connection_pool()
1735        .await
1736        .user_connection_ids(called_user_id)
1737    {
1738        session
1739            .peer
1740            .send(
1741                connection_id,
1742                proto::CallCanceled {
1743                    room_id: room_id.to_proto(),
1744                },
1745            )
1746            .trace_err();
1747    }
1748    response.send(proto::Ack {})?;
1749
1750    update_user_contacts(called_user_id, &session).await?;
1751    Ok(())
1752}
1753
1754/// Decline an incoming call.
1755async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1756    let room_id = RoomId::from_proto(message.room_id);
1757    {
1758        let room = session
1759            .db()
1760            .await
1761            .decline_call(Some(room_id), session.user_id())
1762            .await?
1763            .context("declining call")?;
1764        room_updated(&room, &session.peer);
1765    }
1766
1767    for connection_id in session
1768        .connection_pool()
1769        .await
1770        .user_connection_ids(session.user_id())
1771    {
1772        session
1773            .peer
1774            .send(
1775                connection_id,
1776                proto::CallCanceled {
1777                    room_id: room_id.to_proto(),
1778                },
1779            )
1780            .trace_err();
1781    }
1782    update_user_contacts(session.user_id(), &session).await?;
1783    Ok(())
1784}
1785
1786/// Updates other participants in the room with your current location.
1787async fn update_participant_location(
1788    request: proto::UpdateParticipantLocation,
1789    response: Response<proto::UpdateParticipantLocation>,
1790    session: MessageContext,
1791) -> Result<()> {
1792    let room_id = RoomId::from_proto(request.room_id);
1793    let location = request.location.context("invalid location")?;
1794
1795    let db = session.db().await;
1796    let room = db
1797        .update_room_participant_location(room_id, session.connection_id, location)
1798        .await?;
1799
1800    room_updated(&room, &session.peer);
1801    response.send(proto::Ack {})?;
1802    Ok(())
1803}
1804
1805/// Share a project into the room.
1806async fn share_project(
1807    request: proto::ShareProject,
1808    response: Response<proto::ShareProject>,
1809    session: MessageContext,
1810) -> Result<()> {
1811    let (project_id, room) = &*session
1812        .db()
1813        .await
1814        .share_project(
1815            RoomId::from_proto(request.room_id),
1816            session.connection_id,
1817            &request.worktrees,
1818            request.is_ssh_project,
1819            request.windows_paths.unwrap_or(false),
1820        )
1821        .await?;
1822    response.send(proto::ShareProjectResponse {
1823        project_id: project_id.to_proto(),
1824    })?;
1825    room_updated(room, &session.peer);
1826
1827    Ok(())
1828}
1829
1830/// Unshare a project from the room.
1831async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1832    let project_id = ProjectId::from_proto(message.project_id);
1833    unshare_project_internal(project_id, session.connection_id, &session).await
1834}
1835
1836async fn unshare_project_internal(
1837    project_id: ProjectId,
1838    connection_id: ConnectionId,
1839    session: &Session,
1840) -> Result<()> {
1841    let delete = {
1842        let room_guard = session
1843            .db()
1844            .await
1845            .unshare_project(project_id, connection_id)
1846            .await?;
1847
1848        let (delete, room, guest_connection_ids) = &*room_guard;
1849
1850        let message = proto::UnshareProject {
1851            project_id: project_id.to_proto(),
1852        };
1853
1854        broadcast(
1855            Some(connection_id),
1856            guest_connection_ids.iter().copied(),
1857            |conn_id| session.peer.send(conn_id, message.clone()),
1858        );
1859        if let Some(room) = room {
1860            room_updated(room, &session.peer);
1861        }
1862
1863        *delete
1864    };
1865
1866    if delete {
1867        let db = session.db().await;
1868        db.delete_project(project_id).await?;
1869    }
1870
1871    Ok(())
1872}
1873
1874/// Join someone elses shared project.
1875async fn join_project(
1876    request: proto::JoinProject,
1877    response: Response<proto::JoinProject>,
1878    session: MessageContext,
1879) -> Result<()> {
1880    let project_id = ProjectId::from_proto(request.project_id);
1881
1882    tracing::info!(%project_id, "join project");
1883
1884    let db = session.db().await;
1885    let (project, replica_id) = &mut *db
1886        .join_project(
1887            project_id,
1888            session.connection_id,
1889            session.user_id(),
1890            request.committer_name.clone(),
1891            request.committer_email.clone(),
1892        )
1893        .await?;
1894    drop(db);
1895    tracing::info!(%project_id, "join remote project");
1896    let collaborators = project
1897        .collaborators
1898        .iter()
1899        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1900        .map(|collaborator| collaborator.to_proto())
1901        .collect::<Vec<_>>();
1902    let project_id = project.id;
1903    let guest_user_id = session.user_id();
1904
1905    let worktrees = project
1906        .worktrees
1907        .iter()
1908        .map(|(id, worktree)| proto::WorktreeMetadata {
1909            id: *id,
1910            root_name: worktree.root_name.clone(),
1911            visible: worktree.visible,
1912            abs_path: worktree.abs_path.clone(),
1913        })
1914        .collect::<Vec<_>>();
1915
1916    let add_project_collaborator = proto::AddProjectCollaborator {
1917        project_id: project_id.to_proto(),
1918        collaborator: Some(proto::Collaborator {
1919            peer_id: Some(session.connection_id.into()),
1920            replica_id: replica_id.0 as u32,
1921            user_id: guest_user_id.to_proto(),
1922            is_host: false,
1923            committer_name: request.committer_name.clone(),
1924            committer_email: request.committer_email.clone(),
1925        }),
1926    };
1927
1928    for collaborator in &collaborators {
1929        session
1930            .peer
1931            .send(
1932                collaborator.peer_id.unwrap().into(),
1933                add_project_collaborator.clone(),
1934            )
1935            .trace_err();
1936    }
1937
1938    // First, we send the metadata associated with each worktree.
1939    let (language_servers, language_server_capabilities) = project
1940        .language_servers
1941        .clone()
1942        .into_iter()
1943        .map(|server| (server.server, server.capabilities))
1944        .unzip();
1945    response.send(proto::JoinProjectResponse {
1946        project_id: project.id.0 as u64,
1947        worktrees,
1948        replica_id: replica_id.0 as u32,
1949        collaborators,
1950        language_servers,
1951        language_server_capabilities,
1952        role: project.role.into(),
1953        windows_paths: project.path_style == PathStyle::Windows,
1954    })?;
1955
1956    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1957        // Stream this worktree's entries.
1958        let message = proto::UpdateWorktree {
1959            project_id: project_id.to_proto(),
1960            worktree_id,
1961            abs_path: worktree.abs_path.clone(),
1962            root_name: worktree.root_name,
1963            updated_entries: worktree.entries,
1964            removed_entries: Default::default(),
1965            scan_id: worktree.scan_id,
1966            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1967            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1968            removed_repositories: Default::default(),
1969        };
1970        for update in proto::split_worktree_update(message) {
1971            session.peer.send(session.connection_id, update.clone())?;
1972        }
1973
1974        // Stream this worktree's diagnostics.
1975        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1976        if let Some(summary) = worktree_diagnostics.next() {
1977            let message = proto::UpdateDiagnosticSummary {
1978                project_id: project.id.to_proto(),
1979                worktree_id: worktree.id,
1980                summary: Some(summary),
1981                more_summaries: worktree_diagnostics.collect(),
1982            };
1983            session.peer.send(session.connection_id, message)?;
1984        }
1985
1986        for settings_file in worktree.settings_files {
1987            session.peer.send(
1988                session.connection_id,
1989                proto::UpdateWorktreeSettings {
1990                    project_id: project_id.to_proto(),
1991                    worktree_id: worktree.id,
1992                    path: settings_file.path,
1993                    content: Some(settings_file.content),
1994                    kind: Some(settings_file.kind.to_proto() as i32),
1995                    outside_worktree: Some(settings_file.outside_worktree),
1996                },
1997            )?;
1998        }
1999    }
2000
2001    for repository in mem::take(&mut project.repositories) {
2002        for update in split_repository_update(repository) {
2003            session.peer.send(session.connection_id, update)?;
2004        }
2005    }
2006
2007    for language_server in &project.language_servers {
2008        session.peer.send(
2009            session.connection_id,
2010            proto::UpdateLanguageServer {
2011                project_id: project_id.to_proto(),
2012                server_name: Some(language_server.server.name.clone()),
2013                language_server_id: language_server.server.id,
2014                variant: Some(
2015                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2016                        proto::LspDiskBasedDiagnosticsUpdated {},
2017                    ),
2018                ),
2019            },
2020        )?;
2021    }
2022
2023    Ok(())
2024}
2025
2026/// Leave someone elses shared project.
2027async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2028    let sender_id = session.connection_id;
2029    let project_id = ProjectId::from_proto(request.project_id);
2030    let db = session.db().await;
2031
2032    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2033    tracing::info!(
2034        %project_id,
2035        "leave project"
2036    );
2037
2038    project_left(project, &session);
2039    if let Some(room) = room {
2040        room_updated(room, &session.peer);
2041    }
2042
2043    Ok(())
2044}
2045
2046/// Updates other participants with changes to the project
2047async fn update_project(
2048    request: proto::UpdateProject,
2049    response: Response<proto::UpdateProject>,
2050    session: MessageContext,
2051) -> Result<()> {
2052    let project_id = ProjectId::from_proto(request.project_id);
2053    let (room, guest_connection_ids) = &*session
2054        .db()
2055        .await
2056        .update_project(project_id, session.connection_id, &request.worktrees)
2057        .await?;
2058    broadcast(
2059        Some(session.connection_id),
2060        guest_connection_ids.iter().copied(),
2061        |connection_id| {
2062            session
2063                .peer
2064                .forward_send(session.connection_id, connection_id, request.clone())
2065        },
2066    );
2067    if let Some(room) = room {
2068        room_updated(room, &session.peer);
2069    }
2070    response.send(proto::Ack {})?;
2071
2072    Ok(())
2073}
2074
2075/// Updates other participants with changes to the worktree
2076async fn update_worktree(
2077    request: proto::UpdateWorktree,
2078    response: Response<proto::UpdateWorktree>,
2079    session: MessageContext,
2080) -> Result<()> {
2081    let guest_connection_ids = session
2082        .db()
2083        .await
2084        .update_worktree(&request, session.connection_id)
2085        .await?;
2086
2087    broadcast(
2088        Some(session.connection_id),
2089        guest_connection_ids.iter().copied(),
2090        |connection_id| {
2091            session
2092                .peer
2093                .forward_send(session.connection_id, connection_id, request.clone())
2094        },
2095    );
2096    response.send(proto::Ack {})?;
2097    Ok(())
2098}
2099
2100async fn update_repository(
2101    request: proto::UpdateRepository,
2102    response: Response<proto::UpdateRepository>,
2103    session: MessageContext,
2104) -> Result<()> {
2105    let guest_connection_ids = session
2106        .db()
2107        .await
2108        .update_repository(&request, session.connection_id)
2109        .await?;
2110
2111    broadcast(
2112        Some(session.connection_id),
2113        guest_connection_ids.iter().copied(),
2114        |connection_id| {
2115            session
2116                .peer
2117                .forward_send(session.connection_id, connection_id, request.clone())
2118        },
2119    );
2120    response.send(proto::Ack {})?;
2121    Ok(())
2122}
2123
2124async fn remove_repository(
2125    request: proto::RemoveRepository,
2126    response: Response<proto::RemoveRepository>,
2127    session: MessageContext,
2128) -> Result<()> {
2129    let guest_connection_ids = session
2130        .db()
2131        .await
2132        .remove_repository(&request, session.connection_id)
2133        .await?;
2134
2135    broadcast(
2136        Some(session.connection_id),
2137        guest_connection_ids.iter().copied(),
2138        |connection_id| {
2139            session
2140                .peer
2141                .forward_send(session.connection_id, connection_id, request.clone())
2142        },
2143    );
2144    response.send(proto::Ack {})?;
2145    Ok(())
2146}
2147
2148/// Updates other participants with changes to the diagnostics
2149async fn update_diagnostic_summary(
2150    message: proto::UpdateDiagnosticSummary,
2151    session: MessageContext,
2152) -> Result<()> {
2153    let guest_connection_ids = session
2154        .db()
2155        .await
2156        .update_diagnostic_summary(&message, session.connection_id)
2157        .await?;
2158
2159    broadcast(
2160        Some(session.connection_id),
2161        guest_connection_ids.iter().copied(),
2162        |connection_id| {
2163            session
2164                .peer
2165                .forward_send(session.connection_id, connection_id, message.clone())
2166        },
2167    );
2168
2169    Ok(())
2170}
2171
2172/// Updates other participants with changes to the worktree settings
2173async fn update_worktree_settings(
2174    message: proto::UpdateWorktreeSettings,
2175    session: MessageContext,
2176) -> Result<()> {
2177    let guest_connection_ids = session
2178        .db()
2179        .await
2180        .update_worktree_settings(&message, session.connection_id)
2181        .await?;
2182
2183    broadcast(
2184        Some(session.connection_id),
2185        guest_connection_ids.iter().copied(),
2186        |connection_id| {
2187            session
2188                .peer
2189                .forward_send(session.connection_id, connection_id, message.clone())
2190        },
2191    );
2192
2193    Ok(())
2194}
2195
2196/// Notify other participants that a language server has started.
2197async fn start_language_server(
2198    request: proto::StartLanguageServer,
2199    session: MessageContext,
2200) -> Result<()> {
2201    let guest_connection_ids = session
2202        .db()
2203        .await
2204        .start_language_server(&request, session.connection_id)
2205        .await?;
2206
2207    broadcast(
2208        Some(session.connection_id),
2209        guest_connection_ids.iter().copied(),
2210        |connection_id| {
2211            session
2212                .peer
2213                .forward_send(session.connection_id, connection_id, request.clone())
2214        },
2215    );
2216    Ok(())
2217}
2218
2219/// Notify other participants that a language server has changed.
2220async fn update_language_server(
2221    request: proto::UpdateLanguageServer,
2222    session: MessageContext,
2223) -> Result<()> {
2224    let project_id = ProjectId::from_proto(request.project_id);
2225    let db = session.db().await;
2226
2227    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2228        && let Some(capabilities) = update.capabilities.clone()
2229    {
2230        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2231            .await?;
2232    }
2233
2234    let project_connection_ids = db
2235        .project_connection_ids(project_id, session.connection_id, true)
2236        .await?;
2237    broadcast(
2238        Some(session.connection_id),
2239        project_connection_ids.iter().copied(),
2240        |connection_id| {
2241            session
2242                .peer
2243                .forward_send(session.connection_id, connection_id, request.clone())
2244        },
2245    );
2246    Ok(())
2247}
2248
2249/// forward a project request to the host. These requests should be read only
2250/// as guests are allowed to send them.
2251async fn forward_read_only_project_request<T>(
2252    request: T,
2253    response: Response<T>,
2254    session: MessageContext,
2255) -> Result<()>
2256where
2257    T: EntityMessage + RequestMessage,
2258{
2259    let project_id = ProjectId::from_proto(request.remote_entity_id());
2260    let host_connection_id = session
2261        .db()
2262        .await
2263        .host_for_read_only_project_request(project_id, session.connection_id)
2264        .await?;
2265    let payload = session.forward_request(host_connection_id, request).await?;
2266    response.send(payload)?;
2267    Ok(())
2268}
2269
2270/// forward a project request to the host. These requests are disallowed
2271/// for guests.
2272async fn forward_mutating_project_request<T>(
2273    request: T,
2274    response: Response<T>,
2275    session: MessageContext,
2276) -> Result<()>
2277where
2278    T: EntityMessage + RequestMessage,
2279{
2280    let project_id = ProjectId::from_proto(request.remote_entity_id());
2281
2282    let host_connection_id = session
2283        .db()
2284        .await
2285        .host_for_mutating_project_request(project_id, session.connection_id)
2286        .await?;
2287    let payload = session.forward_request(host_connection_id, request).await?;
2288    response.send(payload)?;
2289    Ok(())
2290}
2291
2292async fn lsp_query(
2293    request: proto::LspQuery,
2294    response: Response<proto::LspQuery>,
2295    session: MessageContext,
2296) -> Result<()> {
2297    let (name, should_write) = request.query_name_and_write_permissions();
2298    tracing::Span::current().record("lsp_query_request", name);
2299    tracing::info!("lsp_query message received");
2300    if should_write {
2301        forward_mutating_project_request(request, response, session).await
2302    } else {
2303        forward_read_only_project_request(request, response, session).await
2304    }
2305}
2306
2307/// Notify other participants that a new buffer has been created
2308async fn create_buffer_for_peer(
2309    request: proto::CreateBufferForPeer,
2310    session: MessageContext,
2311) -> Result<()> {
2312    session
2313        .db()
2314        .await
2315        .check_user_is_project_host(
2316            ProjectId::from_proto(request.project_id),
2317            session.connection_id,
2318        )
2319        .await?;
2320    let peer_id = request.peer_id.context("invalid peer id")?;
2321    session
2322        .peer
2323        .forward_send(session.connection_id, peer_id.into(), request)?;
2324    Ok(())
2325}
2326
2327/// Notify other participants that a new image has been created
2328async fn create_image_for_peer(
2329    request: proto::CreateImageForPeer,
2330    session: MessageContext,
2331) -> Result<()> {
2332    session
2333        .db()
2334        .await
2335        .check_user_is_project_host(
2336            ProjectId::from_proto(request.project_id),
2337            session.connection_id,
2338        )
2339        .await?;
2340    let peer_id = request.peer_id.context("invalid peer id")?;
2341    session
2342        .peer
2343        .forward_send(session.connection_id, peer_id.into(), request)?;
2344    Ok(())
2345}
2346
2347/// Notify other participants that a buffer has been updated. This is
2348/// allowed for guests as long as the update is limited to selections.
2349async fn update_buffer(
2350    request: proto::UpdateBuffer,
2351    response: Response<proto::UpdateBuffer>,
2352    session: MessageContext,
2353) -> Result<()> {
2354    let project_id = ProjectId::from_proto(request.project_id);
2355    let mut capability = Capability::ReadOnly;
2356
2357    for op in request.operations.iter() {
2358        match op.variant {
2359            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2360            Some(_) => capability = Capability::ReadWrite,
2361        }
2362    }
2363
2364    let host = {
2365        let guard = session
2366            .db()
2367            .await
2368            .connections_for_buffer_update(project_id, session.connection_id, capability)
2369            .await?;
2370
2371        let (host, guests) = &*guard;
2372
2373        broadcast(
2374            Some(session.connection_id),
2375            guests.clone(),
2376            |connection_id| {
2377                session
2378                    .peer
2379                    .forward_send(session.connection_id, connection_id, request.clone())
2380            },
2381        );
2382
2383        *host
2384    };
2385
2386    if host != session.connection_id {
2387        session.forward_request(host, request.clone()).await?;
2388    }
2389
2390    response.send(proto::Ack {})?;
2391    Ok(())
2392}
2393
2394async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2395    let project_id = ProjectId::from_proto(message.project_id);
2396
2397    let operation = message.operation.as_ref().context("invalid operation")?;
2398    let capability = match operation.variant.as_ref() {
2399        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2400            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2401                match buffer_op.variant {
2402                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2403                        Capability::ReadOnly
2404                    }
2405                    _ => Capability::ReadWrite,
2406                }
2407            } else {
2408                Capability::ReadWrite
2409            }
2410        }
2411        Some(_) => Capability::ReadWrite,
2412        None => Capability::ReadOnly,
2413    };
2414
2415    let guard = session
2416        .db()
2417        .await
2418        .connections_for_buffer_update(project_id, session.connection_id, capability)
2419        .await?;
2420
2421    let (host, guests) = &*guard;
2422
2423    broadcast(
2424        Some(session.connection_id),
2425        guests.iter().chain([host]).copied(),
2426        |connection_id| {
2427            session
2428                .peer
2429                .forward_send(session.connection_id, connection_id, message.clone())
2430        },
2431    );
2432
2433    Ok(())
2434}
2435
2436async fn forward_project_search_chunk(
2437    message: proto::FindSearchCandidatesChunk,
2438    response: Response<proto::FindSearchCandidatesChunk>,
2439    session: MessageContext,
2440) -> Result<()> {
2441    let peer_id = message.peer_id.context("missing peer_id")?;
2442    let payload = session
2443        .peer
2444        .forward_request(session.connection_id, peer_id.into(), message)
2445        .await?;
2446    response.send(payload)?;
2447    Ok(())
2448}
2449
2450/// Notify other participants that a project has been updated.
2451async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2452    request: T,
2453    session: MessageContext,
2454) -> Result<()> {
2455    let project_id = ProjectId::from_proto(request.remote_entity_id());
2456    let project_connection_ids = session
2457        .db()
2458        .await
2459        .project_connection_ids(project_id, session.connection_id, false)
2460        .await?;
2461
2462    broadcast(
2463        Some(session.connection_id),
2464        project_connection_ids.iter().copied(),
2465        |connection_id| {
2466            session
2467                .peer
2468                .forward_send(session.connection_id, connection_id, request.clone())
2469        },
2470    );
2471    Ok(())
2472}
2473
2474/// Start following another user in a call.
2475async fn follow(
2476    request: proto::Follow,
2477    response: Response<proto::Follow>,
2478    session: MessageContext,
2479) -> Result<()> {
2480    let room_id = RoomId::from_proto(request.room_id);
2481    let project_id = request.project_id.map(ProjectId::from_proto);
2482    let leader_id = request.leader_id.context("invalid leader id")?.into();
2483    let follower_id = session.connection_id;
2484
2485    session
2486        .db()
2487        .await
2488        .check_room_participants(room_id, leader_id, session.connection_id)
2489        .await?;
2490
2491    let response_payload = session.forward_request(leader_id, request).await?;
2492    response.send(response_payload)?;
2493
2494    if let Some(project_id) = project_id {
2495        let room = session
2496            .db()
2497            .await
2498            .follow(room_id, project_id, leader_id, follower_id)
2499            .await?;
2500        room_updated(&room, &session.peer);
2501    }
2502
2503    Ok(())
2504}
2505
2506/// Stop following another user in a call.
2507async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2508    let room_id = RoomId::from_proto(request.room_id);
2509    let project_id = request.project_id.map(ProjectId::from_proto);
2510    let leader_id = request.leader_id.context("invalid leader id")?.into();
2511    let follower_id = session.connection_id;
2512
2513    session
2514        .db()
2515        .await
2516        .check_room_participants(room_id, leader_id, session.connection_id)
2517        .await?;
2518
2519    session
2520        .peer
2521        .forward_send(session.connection_id, leader_id, request)?;
2522
2523    if let Some(project_id) = project_id {
2524        let room = session
2525            .db()
2526            .await
2527            .unfollow(room_id, project_id, leader_id, follower_id)
2528            .await?;
2529        room_updated(&room, &session.peer);
2530    }
2531
2532    Ok(())
2533}
2534
2535/// Notify everyone following you of your current location.
2536async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2537    let room_id = RoomId::from_proto(request.room_id);
2538    let database = session.db.lock().await;
2539
2540    let connection_ids = if let Some(project_id) = request.project_id {
2541        let project_id = ProjectId::from_proto(project_id);
2542        database
2543            .project_connection_ids(project_id, session.connection_id, true)
2544            .await?
2545    } else {
2546        database
2547            .room_connection_ids(room_id, session.connection_id)
2548            .await?
2549    };
2550
2551    // For now, don't send view update messages back to that view's current leader.
2552    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2553        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2554        _ => None,
2555    });
2556
2557    for connection_id in connection_ids.iter().cloned() {
2558        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2559            session
2560                .peer
2561                .forward_send(session.connection_id, connection_id, request.clone())?;
2562        }
2563    }
2564    Ok(())
2565}
2566
2567/// Get public data about users.
2568async fn get_users(
2569    request: proto::GetUsers,
2570    response: Response<proto::GetUsers>,
2571    session: MessageContext,
2572) -> Result<()> {
2573    let user_ids = request
2574        .user_ids
2575        .into_iter()
2576        .map(UserId::from_proto)
2577        .collect();
2578    let users = session
2579        .db()
2580        .await
2581        .get_users_by_ids(user_ids)
2582        .await?
2583        .into_iter()
2584        .map(|user| proto::User {
2585            id: user.id.to_proto(),
2586            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2587            github_login: user.github_login,
2588            name: user.name,
2589        })
2590        .collect();
2591    response.send(proto::UsersResponse { users })?;
2592    Ok(())
2593}
2594
2595/// Search for users (to invite) buy Github login
2596async fn fuzzy_search_users(
2597    request: proto::FuzzySearchUsers,
2598    response: Response<proto::FuzzySearchUsers>,
2599    session: MessageContext,
2600) -> Result<()> {
2601    let query = request.query;
2602    let users = match query.len() {
2603        0 => vec![],
2604        1 | 2 => session
2605            .db()
2606            .await
2607            .get_user_by_github_login(&query)
2608            .await?
2609            .into_iter()
2610            .collect(),
2611        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2612    };
2613    let users = users
2614        .into_iter()
2615        .filter(|user| user.id != session.user_id())
2616        .map(|user| proto::User {
2617            id: user.id.to_proto(),
2618            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2619            github_login: user.github_login,
2620            name: user.name,
2621        })
2622        .collect();
2623    response.send(proto::UsersResponse { users })?;
2624    Ok(())
2625}
2626
2627/// Send a contact request to another user.
2628async fn request_contact(
2629    request: proto::RequestContact,
2630    response: Response<proto::RequestContact>,
2631    session: MessageContext,
2632) -> Result<()> {
2633    let requester_id = session.user_id();
2634    let responder_id = UserId::from_proto(request.responder_id);
2635    if requester_id == responder_id {
2636        return Err(anyhow!("cannot add yourself as a contact"))?;
2637    }
2638
2639    let notifications = session
2640        .db()
2641        .await
2642        .send_contact_request(requester_id, responder_id)
2643        .await?;
2644
2645    // Update outgoing contact requests of requester
2646    let mut update = proto::UpdateContacts::default();
2647    update.outgoing_requests.push(responder_id.to_proto());
2648    for connection_id in session
2649        .connection_pool()
2650        .await
2651        .user_connection_ids(requester_id)
2652    {
2653        session.peer.send(connection_id, update.clone())?;
2654    }
2655
2656    // Update incoming contact requests of responder
2657    let mut update = proto::UpdateContacts::default();
2658    update
2659        .incoming_requests
2660        .push(proto::IncomingContactRequest {
2661            requester_id: requester_id.to_proto(),
2662        });
2663    let connection_pool = session.connection_pool().await;
2664    for connection_id in connection_pool.user_connection_ids(responder_id) {
2665        session.peer.send(connection_id, update.clone())?;
2666    }
2667
2668    send_notifications(&connection_pool, &session.peer, notifications);
2669
2670    response.send(proto::Ack {})?;
2671    Ok(())
2672}
2673
2674/// Accept or decline a contact request
2675async fn respond_to_contact_request(
2676    request: proto::RespondToContactRequest,
2677    response: Response<proto::RespondToContactRequest>,
2678    session: MessageContext,
2679) -> Result<()> {
2680    let responder_id = session.user_id();
2681    let requester_id = UserId::from_proto(request.requester_id);
2682    let db = session.db().await;
2683    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2684        db.dismiss_contact_notification(responder_id, requester_id)
2685            .await?;
2686    } else {
2687        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2688
2689        let notifications = db
2690            .respond_to_contact_request(responder_id, requester_id, accept)
2691            .await?;
2692        let requester_busy = db.is_user_busy(requester_id).await?;
2693        let responder_busy = db.is_user_busy(responder_id).await?;
2694
2695        let pool = session.connection_pool().await;
2696        // Update responder with new contact
2697        let mut update = proto::UpdateContacts::default();
2698        if accept {
2699            update
2700                .contacts
2701                .push(contact_for_user(requester_id, requester_busy, &pool));
2702        }
2703        update
2704            .remove_incoming_requests
2705            .push(requester_id.to_proto());
2706        for connection_id in pool.user_connection_ids(responder_id) {
2707            session.peer.send(connection_id, update.clone())?;
2708        }
2709
2710        // Update requester with new contact
2711        let mut update = proto::UpdateContacts::default();
2712        if accept {
2713            update
2714                .contacts
2715                .push(contact_for_user(responder_id, responder_busy, &pool));
2716        }
2717        update
2718            .remove_outgoing_requests
2719            .push(responder_id.to_proto());
2720
2721        for connection_id in pool.user_connection_ids(requester_id) {
2722            session.peer.send(connection_id, update.clone())?;
2723        }
2724
2725        send_notifications(&pool, &session.peer, notifications);
2726    }
2727
2728    response.send(proto::Ack {})?;
2729    Ok(())
2730}
2731
2732/// Remove a contact.
2733async fn remove_contact(
2734    request: proto::RemoveContact,
2735    response: Response<proto::RemoveContact>,
2736    session: MessageContext,
2737) -> Result<()> {
2738    let requester_id = session.user_id();
2739    let responder_id = UserId::from_proto(request.user_id);
2740    let db = session.db().await;
2741    let (contact_accepted, deleted_notification_id) =
2742        db.remove_contact(requester_id, responder_id).await?;
2743
2744    let pool = session.connection_pool().await;
2745    // Update outgoing contact requests of requester
2746    let mut update = proto::UpdateContacts::default();
2747    if contact_accepted {
2748        update.remove_contacts.push(responder_id.to_proto());
2749    } else {
2750        update
2751            .remove_outgoing_requests
2752            .push(responder_id.to_proto());
2753    }
2754    for connection_id in pool.user_connection_ids(requester_id) {
2755        session.peer.send(connection_id, update.clone())?;
2756    }
2757
2758    // Update incoming contact requests of responder
2759    let mut update = proto::UpdateContacts::default();
2760    if contact_accepted {
2761        update.remove_contacts.push(requester_id.to_proto());
2762    } else {
2763        update
2764            .remove_incoming_requests
2765            .push(requester_id.to_proto());
2766    }
2767    for connection_id in pool.user_connection_ids(responder_id) {
2768        session.peer.send(connection_id, update.clone())?;
2769        if let Some(notification_id) = deleted_notification_id {
2770            session.peer.send(
2771                connection_id,
2772                proto::DeleteNotification {
2773                    notification_id: notification_id.to_proto(),
2774                },
2775            )?;
2776        }
2777    }
2778
2779    response.send(proto::Ack {})?;
2780    Ok(())
2781}
2782
2783fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
2784    version.0.minor < 139
2785}
2786
2787async fn subscribe_to_channels(
2788    _: proto::SubscribeToChannels,
2789    session: MessageContext,
2790) -> Result<()> {
2791    subscribe_user_to_channels(session.user_id(), &session).await?;
2792    Ok(())
2793}
2794
2795async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2796    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2797    let mut pool = session.connection_pool().await;
2798    for membership in &channels_for_user.channel_memberships {
2799        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2800    }
2801    session.peer.send(
2802        session.connection_id,
2803        build_update_user_channels(&channels_for_user),
2804    )?;
2805    session.peer.send(
2806        session.connection_id,
2807        build_channels_update(channels_for_user),
2808    )?;
2809    Ok(())
2810}
2811
2812/// Creates a new channel.
2813async fn create_channel(
2814    request: proto::CreateChannel,
2815    response: Response<proto::CreateChannel>,
2816    session: MessageContext,
2817) -> Result<()> {
2818    let db = session.db().await;
2819
2820    let parent_id = request.parent_id.map(ChannelId::from_proto);
2821    let (channel, membership) = db
2822        .create_channel(&request.name, parent_id, session.user_id())
2823        .await?;
2824
2825    let root_id = channel.root_id();
2826    let channel = Channel::from_model(channel);
2827
2828    response.send(proto::CreateChannelResponse {
2829        channel: Some(channel.to_proto()),
2830        parent_id: request.parent_id,
2831    })?;
2832
2833    let mut connection_pool = session.connection_pool().await;
2834    if let Some(membership) = membership {
2835        connection_pool.subscribe_to_channel(
2836            membership.user_id,
2837            membership.channel_id,
2838            membership.role,
2839        );
2840        let update = proto::UpdateUserChannels {
2841            channel_memberships: vec![proto::ChannelMembership {
2842                channel_id: membership.channel_id.to_proto(),
2843                role: membership.role.into(),
2844            }],
2845            ..Default::default()
2846        };
2847        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2848            session.peer.send(connection_id, update.clone())?;
2849        }
2850    }
2851
2852    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2853        if !role.can_see_channel(channel.visibility) {
2854            continue;
2855        }
2856
2857        let update = proto::UpdateChannels {
2858            channels: vec![channel.to_proto()],
2859            ..Default::default()
2860        };
2861        session.peer.send(connection_id, update.clone())?;
2862    }
2863
2864    Ok(())
2865}
2866
2867/// Delete a channel
2868async fn delete_channel(
2869    request: proto::DeleteChannel,
2870    response: Response<proto::DeleteChannel>,
2871    session: MessageContext,
2872) -> Result<()> {
2873    let db = session.db().await;
2874
2875    let channel_id = request.channel_id;
2876    let (root_channel, removed_channels) = db
2877        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2878        .await?;
2879    response.send(proto::Ack {})?;
2880
2881    // Notify members of removed channels
2882    let mut update = proto::UpdateChannels::default();
2883    update
2884        .delete_channels
2885        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2886
2887    let connection_pool = session.connection_pool().await;
2888    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2889        session.peer.send(connection_id, update.clone())?;
2890    }
2891
2892    Ok(())
2893}
2894
2895/// Invite someone to join a channel.
2896async fn invite_channel_member(
2897    request: proto::InviteChannelMember,
2898    response: Response<proto::InviteChannelMember>,
2899    session: MessageContext,
2900) -> Result<()> {
2901    let db = session.db().await;
2902    let channel_id = ChannelId::from_proto(request.channel_id);
2903    let invitee_id = UserId::from_proto(request.user_id);
2904    let InviteMemberResult {
2905        channel,
2906        notifications,
2907    } = db
2908        .invite_channel_member(
2909            channel_id,
2910            invitee_id,
2911            session.user_id(),
2912            request.role().into(),
2913        )
2914        .await?;
2915
2916    let update = proto::UpdateChannels {
2917        channel_invitations: vec![channel.to_proto()],
2918        ..Default::default()
2919    };
2920
2921    let connection_pool = session.connection_pool().await;
2922    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2923        session.peer.send(connection_id, update.clone())?;
2924    }
2925
2926    send_notifications(&connection_pool, &session.peer, notifications);
2927
2928    response.send(proto::Ack {})?;
2929    Ok(())
2930}
2931
2932/// remove someone from a channel
2933async fn remove_channel_member(
2934    request: proto::RemoveChannelMember,
2935    response: Response<proto::RemoveChannelMember>,
2936    session: MessageContext,
2937) -> Result<()> {
2938    let db = session.db().await;
2939    let channel_id = ChannelId::from_proto(request.channel_id);
2940    let member_id = UserId::from_proto(request.user_id);
2941
2942    let RemoveChannelMemberResult {
2943        membership_update,
2944        notification_id,
2945    } = db
2946        .remove_channel_member(channel_id, member_id, session.user_id())
2947        .await?;
2948
2949    let mut connection_pool = session.connection_pool().await;
2950    notify_membership_updated(
2951        &mut connection_pool,
2952        membership_update,
2953        member_id,
2954        &session.peer,
2955    );
2956    for connection_id in connection_pool.user_connection_ids(member_id) {
2957        if let Some(notification_id) = notification_id {
2958            session
2959                .peer
2960                .send(
2961                    connection_id,
2962                    proto::DeleteNotification {
2963                        notification_id: notification_id.to_proto(),
2964                    },
2965                )
2966                .trace_err();
2967        }
2968    }
2969
2970    response.send(proto::Ack {})?;
2971    Ok(())
2972}
2973
2974/// Toggle the channel between public and private.
2975/// Care is taken to maintain the invariant that public channels only descend from public channels,
2976/// (though members-only channels can appear at any point in the hierarchy).
2977async fn set_channel_visibility(
2978    request: proto::SetChannelVisibility,
2979    response: Response<proto::SetChannelVisibility>,
2980    session: MessageContext,
2981) -> Result<()> {
2982    let db = session.db().await;
2983    let channel_id = ChannelId::from_proto(request.channel_id);
2984    let visibility = request.visibility().into();
2985
2986    let channel_model = db
2987        .set_channel_visibility(channel_id, visibility, session.user_id())
2988        .await?;
2989    let root_id = channel_model.root_id();
2990    let channel = Channel::from_model(channel_model);
2991
2992    let mut connection_pool = session.connection_pool().await;
2993    for (user_id, role) in connection_pool
2994        .channel_user_ids(root_id)
2995        .collect::<Vec<_>>()
2996        .into_iter()
2997    {
2998        let update = if role.can_see_channel(channel.visibility) {
2999            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3000            proto::UpdateChannels {
3001                channels: vec![channel.to_proto()],
3002                ..Default::default()
3003            }
3004        } else {
3005            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3006            proto::UpdateChannels {
3007                delete_channels: vec![channel.id.to_proto()],
3008                ..Default::default()
3009            }
3010        };
3011
3012        for connection_id in connection_pool.user_connection_ids(user_id) {
3013            session.peer.send(connection_id, update.clone())?;
3014        }
3015    }
3016
3017    response.send(proto::Ack {})?;
3018    Ok(())
3019}
3020
3021/// Alter the role for a user in the channel.
3022async fn set_channel_member_role(
3023    request: proto::SetChannelMemberRole,
3024    response: Response<proto::SetChannelMemberRole>,
3025    session: MessageContext,
3026) -> Result<()> {
3027    let db = session.db().await;
3028    let channel_id = ChannelId::from_proto(request.channel_id);
3029    let member_id = UserId::from_proto(request.user_id);
3030    let result = db
3031        .set_channel_member_role(
3032            channel_id,
3033            session.user_id(),
3034            member_id,
3035            request.role().into(),
3036        )
3037        .await?;
3038
3039    match result {
3040        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3041            let mut connection_pool = session.connection_pool().await;
3042            notify_membership_updated(
3043                &mut connection_pool,
3044                membership_update,
3045                member_id,
3046                &session.peer,
3047            )
3048        }
3049        db::SetMemberRoleResult::InviteUpdated(channel) => {
3050            let update = proto::UpdateChannels {
3051                channel_invitations: vec![channel.to_proto()],
3052                ..Default::default()
3053            };
3054
3055            for connection_id in session
3056                .connection_pool()
3057                .await
3058                .user_connection_ids(member_id)
3059            {
3060                session.peer.send(connection_id, update.clone())?;
3061            }
3062        }
3063    }
3064
3065    response.send(proto::Ack {})?;
3066    Ok(())
3067}
3068
3069/// Change the name of a channel
3070async fn rename_channel(
3071    request: proto::RenameChannel,
3072    response: Response<proto::RenameChannel>,
3073    session: MessageContext,
3074) -> Result<()> {
3075    let db = session.db().await;
3076    let channel_id = ChannelId::from_proto(request.channel_id);
3077    let channel_model = db
3078        .rename_channel(channel_id, session.user_id(), &request.name)
3079        .await?;
3080    let root_id = channel_model.root_id();
3081    let channel = Channel::from_model(channel_model);
3082
3083    response.send(proto::RenameChannelResponse {
3084        channel: Some(channel.to_proto()),
3085    })?;
3086
3087    let connection_pool = session.connection_pool().await;
3088    let update = proto::UpdateChannels {
3089        channels: vec![channel.to_proto()],
3090        ..Default::default()
3091    };
3092    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3093        if role.can_see_channel(channel.visibility) {
3094            session.peer.send(connection_id, update.clone())?;
3095        }
3096    }
3097
3098    Ok(())
3099}
3100
3101/// Move a channel to a new parent.
3102async fn move_channel(
3103    request: proto::MoveChannel,
3104    response: Response<proto::MoveChannel>,
3105    session: MessageContext,
3106) -> Result<()> {
3107    let channel_id = ChannelId::from_proto(request.channel_id);
3108    let to = ChannelId::from_proto(request.to);
3109
3110    let (root_id, channels) = session
3111        .db()
3112        .await
3113        .move_channel(channel_id, to, session.user_id())
3114        .await?;
3115
3116    let connection_pool = session.connection_pool().await;
3117    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3118        let channels = channels
3119            .iter()
3120            .filter_map(|channel| {
3121                if role.can_see_channel(channel.visibility) {
3122                    Some(channel.to_proto())
3123                } else {
3124                    None
3125                }
3126            })
3127            .collect::<Vec<_>>();
3128        if channels.is_empty() {
3129            continue;
3130        }
3131
3132        let update = proto::UpdateChannels {
3133            channels,
3134            ..Default::default()
3135        };
3136
3137        session.peer.send(connection_id, update.clone())?;
3138    }
3139
3140    response.send(Ack {})?;
3141    Ok(())
3142}
3143
3144async fn reorder_channel(
3145    request: proto::ReorderChannel,
3146    response: Response<proto::ReorderChannel>,
3147    session: MessageContext,
3148) -> Result<()> {
3149    let channel_id = ChannelId::from_proto(request.channel_id);
3150    let direction = request.direction();
3151
3152    let updated_channels = session
3153        .db()
3154        .await
3155        .reorder_channel(channel_id, direction, session.user_id())
3156        .await?;
3157
3158    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3159        let connection_pool = session.connection_pool().await;
3160        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3161            let channels = updated_channels
3162                .iter()
3163                .filter_map(|channel| {
3164                    if role.can_see_channel(channel.visibility) {
3165                        Some(channel.to_proto())
3166                    } else {
3167                        None
3168                    }
3169                })
3170                .collect::<Vec<_>>();
3171
3172            if channels.is_empty() {
3173                continue;
3174            }
3175
3176            let update = proto::UpdateChannels {
3177                channels,
3178                ..Default::default()
3179            };
3180
3181            session.peer.send(connection_id, update.clone())?;
3182        }
3183    }
3184
3185    response.send(Ack {})?;
3186    Ok(())
3187}
3188
3189/// Get the list of channel members
3190async fn get_channel_members(
3191    request: proto::GetChannelMembers,
3192    response: Response<proto::GetChannelMembers>,
3193    session: MessageContext,
3194) -> Result<()> {
3195    let db = session.db().await;
3196    let channel_id = ChannelId::from_proto(request.channel_id);
3197    let limit = if request.limit == 0 {
3198        u16::MAX as u64
3199    } else {
3200        request.limit
3201    };
3202    let (members, users) = db
3203        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3204        .await?;
3205    response.send(proto::GetChannelMembersResponse { members, users })?;
3206    Ok(())
3207}
3208
3209/// Accept or decline a channel invitation.
3210async fn respond_to_channel_invite(
3211    request: proto::RespondToChannelInvite,
3212    response: Response<proto::RespondToChannelInvite>,
3213    session: MessageContext,
3214) -> Result<()> {
3215    let db = session.db().await;
3216    let channel_id = ChannelId::from_proto(request.channel_id);
3217    let RespondToChannelInvite {
3218        membership_update,
3219        notifications,
3220    } = db
3221        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3222        .await?;
3223
3224    let mut connection_pool = session.connection_pool().await;
3225    if let Some(membership_update) = membership_update {
3226        notify_membership_updated(
3227            &mut connection_pool,
3228            membership_update,
3229            session.user_id(),
3230            &session.peer,
3231        );
3232    } else {
3233        let update = proto::UpdateChannels {
3234            remove_channel_invitations: vec![channel_id.to_proto()],
3235            ..Default::default()
3236        };
3237
3238        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3239            session.peer.send(connection_id, update.clone())?;
3240        }
3241    };
3242
3243    send_notifications(&connection_pool, &session.peer, notifications);
3244
3245    response.send(proto::Ack {})?;
3246
3247    Ok(())
3248}
3249
3250/// Join the channels' room
3251async fn join_channel(
3252    request: proto::JoinChannel,
3253    response: Response<proto::JoinChannel>,
3254    session: MessageContext,
3255) -> Result<()> {
3256    let channel_id = ChannelId::from_proto(request.channel_id);
3257    join_channel_internal(channel_id, Box::new(response), session).await
3258}
3259
3260trait JoinChannelInternalResponse {
3261    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3262}
3263impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3264    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3265        Response::<proto::JoinChannel>::send(self, result)
3266    }
3267}
3268impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3269    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3270        Response::<proto::JoinRoom>::send(self, result)
3271    }
3272}
3273
/// Shared implementation for joining a channel's room.
///
/// `join_channel` calls this with a `Response<proto::JoinChannel>`. Any response
/// type implementing `JoinChannelInternalResponse` can be used.
///
/// Steps:
/// 1. Evict a stale room connection left behind by this user, if any.
/// 2. Join the channel's room in the database.
/// 3. Build LiveKit credentials (when a LiveKit client is configured) and reply.
/// 4. Notify the user's connections of any membership change, broadcast the
///    room state, broadcast the channel's participants, and refresh contacts.
///
/// Errors if a database call fails, if sending the reply fails, or if the joined
/// room has no channel. The missing-channel check runs after the reply and the
/// room broadcast have already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    // Everything that needs `db` and the first connection-pool guard happens in
    // this block, so both guards are dropped before `channel_updated` and
    // `update_user_contacts` acquire them again below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` acquires `session.db()` itself, so release
            // our handle first and re-acquire it afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests receive a guest token and may not publish. Every other role gets
        // a room token and may publish. If token creation fails, the error is
        // logged by `trace_err` and the reply carries no LiveKit info.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the requester first; broadcasts to other peers follow.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        // `notify_membership_updated` needs mutable access to the pool to change
        // this user's channel subscriptions.
        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Fails the request if the room is not tied to a channel. The reply above
    // has already been sent by this point.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3367
3368/// Start editing the channel notes
3369async fn join_channel_buffer(
3370    request: proto::JoinChannelBuffer,
3371    response: Response<proto::JoinChannelBuffer>,
3372    session: MessageContext,
3373) -> Result<()> {
3374    let db = session.db().await;
3375    let channel_id = ChannelId::from_proto(request.channel_id);
3376
3377    let open_response = db
3378        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3379        .await?;
3380
3381    let collaborators = open_response.collaborators.clone();
3382    response.send(open_response)?;
3383
3384    let update = UpdateChannelBufferCollaborators {
3385        channel_id: channel_id.to_proto(),
3386        collaborators: collaborators.clone(),
3387    };
3388    channel_buffer_updated(
3389        session.connection_id,
3390        collaborators
3391            .iter()
3392            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3393        &update,
3394        &session.peer,
3395    );
3396
3397    Ok(())
3398}
3399
3400/// Edit the channel notes
3401async fn update_channel_buffer(
3402    request: proto::UpdateChannelBuffer,
3403    session: MessageContext,
3404) -> Result<()> {
3405    let db = session.db().await;
3406    let channel_id = ChannelId::from_proto(request.channel_id);
3407
3408    let (collaborators, epoch, version) = db
3409        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3410        .await?;
3411
3412    channel_buffer_updated(
3413        session.connection_id,
3414        collaborators.clone(),
3415        &proto::UpdateChannelBuffer {
3416            channel_id: channel_id.to_proto(),
3417            operations: request.operations,
3418        },
3419        &session.peer,
3420    );
3421
3422    let pool = &*session.connection_pool().await;
3423
3424    let non_collaborators =
3425        pool.channel_connection_ids(channel_id)
3426            .filter_map(|(connection_id, _)| {
3427                if collaborators.contains(&connection_id) {
3428                    None
3429                } else {
3430                    Some(connection_id)
3431                }
3432            });
3433
3434    broadcast(None, non_collaborators, |peer_id| {
3435        session.peer.send(
3436            peer_id,
3437            proto::UpdateChannels {
3438                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3439                    channel_id: channel_id.to_proto(),
3440                    epoch: epoch as u64,
3441                    version: version.clone(),
3442                }],
3443                ..Default::default()
3444            },
3445        )
3446    });
3447
3448    Ok(())
3449}
3450
3451/// Rejoin the channel notes after a connection blip
3452async fn rejoin_channel_buffers(
3453    request: proto::RejoinChannelBuffers,
3454    response: Response<proto::RejoinChannelBuffers>,
3455    session: MessageContext,
3456) -> Result<()> {
3457    let db = session.db().await;
3458    let buffers = db
3459        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3460        .await?;
3461
3462    for rejoined_buffer in &buffers {
3463        let collaborators_to_notify = rejoined_buffer
3464            .buffer
3465            .collaborators
3466            .iter()
3467            .filter_map(|c| Some(c.peer_id?.into()));
3468        channel_buffer_updated(
3469            session.connection_id,
3470            collaborators_to_notify,
3471            &proto::UpdateChannelBufferCollaborators {
3472                channel_id: rejoined_buffer.buffer.channel_id,
3473                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3474            },
3475            &session.peer,
3476        );
3477    }
3478
3479    response.send(proto::RejoinChannelBuffersResponse {
3480        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3481    })?;
3482
3483    Ok(())
3484}
3485
3486/// Stop editing the channel notes
3487async fn leave_channel_buffer(
3488    request: proto::LeaveChannelBuffer,
3489    response: Response<proto::LeaveChannelBuffer>,
3490    session: MessageContext,
3491) -> Result<()> {
3492    let db = session.db().await;
3493    let channel_id = ChannelId::from_proto(request.channel_id);
3494
3495    let left_buffer = db
3496        .leave_channel_buffer(channel_id, session.connection_id)
3497        .await?;
3498
3499    response.send(Ack {})?;
3500
3501    channel_buffer_updated(
3502        session.connection_id,
3503        left_buffer.connections,
3504        &proto::UpdateChannelBufferCollaborators {
3505            channel_id: channel_id.to_proto(),
3506            collaborators: left_buffer.collaborators,
3507        },
3508        &session.peer,
3509    );
3510
3511    Ok(())
3512}
3513
3514fn channel_buffer_updated<T: EnvelopedMessage>(
3515    sender_id: ConnectionId,
3516    collaborators: impl IntoIterator<Item = ConnectionId>,
3517    message: &T,
3518    peer: &Peer,
3519) {
3520    broadcast(Some(sender_id), collaborators, |peer_id| {
3521        peer.send(peer_id, message.clone())
3522    });
3523}
3524
3525fn send_notifications(
3526    connection_pool: &ConnectionPool,
3527    peer: &Peer,
3528    notifications: db::NotificationBatch,
3529) {
3530    for (user_id, notification) in notifications {
3531        for connection_id in connection_pool.user_connection_ids(user_id) {
3532            if let Err(error) = peer.send(
3533                connection_id,
3534                proto::AddNotification {
3535                    notification: Some(notification.clone()),
3536                },
3537            ) {
3538                tracing::error!(
3539                    "failed to send notification to {:?} {}",
3540                    connection_id,
3541                    error
3542                );
3543            }
3544        }
3545    }
3546}
3547
3548/// Send a message to the channel
3549async fn send_channel_message(
3550    _request: proto::SendChannelMessage,
3551    _response: Response<proto::SendChannelMessage>,
3552    _session: MessageContext,
3553) -> Result<()> {
3554    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3555}
3556
3557/// Delete a channel message
3558async fn remove_channel_message(
3559    _request: proto::RemoveChannelMessage,
3560    _response: Response<proto::RemoveChannelMessage>,
3561    _session: MessageContext,
3562) -> Result<()> {
3563    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3564}
3565
3566async fn update_channel_message(
3567    _request: proto::UpdateChannelMessage,
3568    _response: Response<proto::UpdateChannelMessage>,
3569    _session: MessageContext,
3570) -> Result<()> {
3571    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3572}
3573
3574/// Mark a channel message as read
3575async fn acknowledge_channel_message(
3576    _request: proto::AckChannelMessage,
3577    _session: MessageContext,
3578) -> Result<()> {
3579    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3580}
3581
3582/// Mark a buffer version as synced
3583async fn acknowledge_buffer_version(
3584    request: proto::AckBufferOperation,
3585    session: MessageContext,
3586) -> Result<()> {
3587    let buffer_id = BufferId::from_proto(request.buffer_id);
3588    session
3589        .db()
3590        .await
3591        .observe_buffer_version(
3592            buffer_id,
3593            session.user_id(),
3594            request.epoch as i32,
3595            &request.version,
3596        )
3597        .await?;
3598    Ok(())
3599}
3600
3601/// Start receiving chat updates for a channel
3602async fn join_channel_chat(
3603    _request: proto::JoinChannelChat,
3604    _response: Response<proto::JoinChannelChat>,
3605    _session: MessageContext,
3606) -> Result<()> {
3607    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3608}
3609
3610/// Stop receiving chat updates for a channel
3611async fn leave_channel_chat(
3612    _request: proto::LeaveChannelChat,
3613    _session: MessageContext,
3614) -> Result<()> {
3615    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3616}
3617
3618/// Retrieve the chat history for a channel
3619async fn get_channel_messages(
3620    _request: proto::GetChannelMessages,
3621    _response: Response<proto::GetChannelMessages>,
3622    _session: MessageContext,
3623) -> Result<()> {
3624    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3625}
3626
3627/// Retrieve specific chat messages
3628async fn get_channel_messages_by_id(
3629    _request: proto::GetChannelMessagesById,
3630    _response: Response<proto::GetChannelMessagesById>,
3631    _session: MessageContext,
3632) -> Result<()> {
3633    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3634}
3635
3636/// Retrieve the current users notifications
3637async fn get_notifications(
3638    request: proto::GetNotifications,
3639    response: Response<proto::GetNotifications>,
3640    session: MessageContext,
3641) -> Result<()> {
3642    let notifications = session
3643        .db()
3644        .await
3645        .get_notifications(
3646            session.user_id(),
3647            NOTIFICATION_COUNT_PER_PAGE,
3648            request.before_id.map(db::NotificationId::from_proto),
3649        )
3650        .await?;
3651    response.send(proto::GetNotificationsResponse {
3652        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3653        notifications,
3654    })?;
3655    Ok(())
3656}
3657
3658/// Mark notifications as read
3659async fn mark_notification_as_read(
3660    request: proto::MarkNotificationRead,
3661    response: Response<proto::MarkNotificationRead>,
3662    session: MessageContext,
3663) -> Result<()> {
3664    let database = &session.db().await;
3665    let notifications = database
3666        .mark_notification_as_read_by_id(
3667            session.user_id(),
3668            NotificationId::from_proto(request.notification_id),
3669        )
3670        .await?;
3671    send_notifications(
3672        &*session.connection_pool().await,
3673        &session.peer,
3674        notifications,
3675    );
3676    response.send(proto::Ack {})?;
3677    Ok(())
3678}
3679
3680fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3681    let message = match message {
3682        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3683        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3684        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3685        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3686        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3687            code: frame.code.into(),
3688            reason: frame.reason.as_str().to_owned().into(),
3689        })),
3690        // We should never receive a frame while reading the message, according
3691        // to the `tungstenite` maintainers:
3692        //
3693        // > It cannot occur when you read messages from the WebSocket, but it
3694        // > can be used when you want to send the raw frames (e.g. you want to
3695        // > send the frames to the WebSocket without composing the full message first).
3696        // >
3697        // > — https://github.com/snapview/tungstenite-rs/issues/268
3698        TungsteniteMessage::Frame(_) => {
3699            bail!("received an unexpected frame while reading the message")
3700        }
3701    };
3702
3703    Ok(message)
3704}
3705
3706fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3707    match message {
3708        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3709        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3710        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3711        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3712        AxumMessage::Close(frame) => {
3713            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3714                code: frame.code.into(),
3715                reason: frame.reason.as_ref().into(),
3716            }))
3717        }
3718    }
3719}
3720
3721fn notify_membership_updated(
3722    connection_pool: &mut ConnectionPool,
3723    result: MembershipUpdated,
3724    user_id: UserId,
3725    peer: &Peer,
3726) {
3727    for membership in &result.new_channels.channel_memberships {
3728        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3729    }
3730    for channel_id in &result.removed_channels {
3731        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3732    }
3733
3734    let user_channels_update = proto::UpdateUserChannels {
3735        channel_memberships: result
3736            .new_channels
3737            .channel_memberships
3738            .iter()
3739            .map(|cm| proto::ChannelMembership {
3740                channel_id: cm.channel_id.to_proto(),
3741                role: cm.role.into(),
3742            })
3743            .collect(),
3744        ..Default::default()
3745    };
3746
3747    let mut update = build_channels_update(result.new_channels);
3748    update.delete_channels = result
3749        .removed_channels
3750        .into_iter()
3751        .map(|id| id.to_proto())
3752        .collect();
3753    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3754
3755    for connection_id in connection_pool.user_connection_ids(user_id) {
3756        peer.send(connection_id, user_channels_update.clone())
3757            .trace_err();
3758        peer.send(connection_id, update.clone()).trace_err();
3759    }
3760}
3761
3762fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3763    proto::UpdateUserChannels {
3764        channel_memberships: channels
3765            .channel_memberships
3766            .iter()
3767            .map(|m| proto::ChannelMembership {
3768                channel_id: m.channel_id.to_proto(),
3769                role: m.role.into(),
3770            })
3771            .collect(),
3772        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3773    }
3774}
3775
3776fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3777    let mut update = proto::UpdateChannels::default();
3778
3779    for channel in channels.channels {
3780        update.channels.push(channel.to_proto());
3781    }
3782
3783    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3784
3785    for (channel_id, participants) in channels.channel_participants {
3786        update
3787            .channel_participants
3788            .push(proto::ChannelParticipants {
3789                channel_id: channel_id.to_proto(),
3790                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3791            });
3792    }
3793
3794    for channel in channels.invited_channels {
3795        update.channel_invitations.push(channel.to_proto());
3796    }
3797
3798    update
3799}
3800
3801fn build_initial_contacts_update(
3802    contacts: Vec<db::Contact>,
3803    pool: &ConnectionPool,
3804) -> proto::UpdateContacts {
3805    let mut update = proto::UpdateContacts::default();
3806
3807    for contact in contacts {
3808        match contact {
3809            db::Contact::Accepted { user_id, busy } => {
3810                update.contacts.push(contact_for_user(user_id, busy, pool));
3811            }
3812            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3813            db::Contact::Incoming { user_id } => {
3814                update
3815                    .incoming_requests
3816                    .push(proto::IncomingContactRequest {
3817                        requester_id: user_id.to_proto(),
3818                    })
3819            }
3820        }
3821    }
3822
3823    update
3824}
3825
3826fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3827    proto::Contact {
3828        user_id: user_id.to_proto(),
3829        online: pool.is_user_online(user_id),
3830        busy,
3831    }
3832}
3833
3834fn room_updated(room: &proto::Room, peer: &Peer) {
3835    broadcast(
3836        None,
3837        room.participants
3838            .iter()
3839            .filter_map(|participant| Some(participant.peer_id?.into())),
3840        |peer_id| {
3841            peer.send(
3842                peer_id,
3843                proto::RoomUpdated {
3844                    room: Some(room.clone()),
3845                },
3846            )
3847        },
3848    );
3849}
3850
3851fn channel_updated(
3852    channel: &db::channel::Model,
3853    room: &proto::Room,
3854    peer: &Peer,
3855    pool: &ConnectionPool,
3856) {
3857    let participants = room
3858        .participants
3859        .iter()
3860        .map(|p| p.user_id)
3861        .collect::<Vec<_>>();
3862
3863    broadcast(
3864        None,
3865        pool.channel_connection_ids(channel.root_id())
3866            .filter_map(|(channel_id, role)| {
3867                role.can_see_channel(channel.visibility)
3868                    .then_some(channel_id)
3869            }),
3870        |peer_id| {
3871            peer.send(
3872                peer_id,
3873                proto::UpdateChannels {
3874                    channel_participants: vec![proto::ChannelParticipants {
3875                        channel_id: channel.id.to_proto(),
3876                        participant_user_ids: participants.clone(),
3877                    }],
3878                    ..Default::default()
3879                },
3880            )
3881        },
3882    );
3883}
3884
3885async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3886    let db = session.db().await;
3887
3888    let contacts = db.get_contacts(user_id).await?;
3889    let busy = db.is_user_busy(user_id).await?;
3890
3891    let pool = session.connection_pool().await;
3892    let updated_contact = contact_for_user(user_id, busy, &pool);
3893    for contact in contacts {
3894        if let db::Contact::Accepted {
3895            user_id: contact_user_id,
3896            ..
3897        } = contact
3898        {
3899            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3900                session
3901                    .peer
3902                    .send(
3903                        contact_conn_id,
3904                        proto::UpdateContacts {
3905                            contacts: vec![updated_contact.clone()],
3906                            remove_contacts: Default::default(),
3907                            incoming_requests: Default::default(),
3908                            remove_incoming_requests: Default::default(),
3909                            outgoing_requests: Default::default(),
3910                            remove_outgoing_requests: Default::default(),
3911                        },
3912                    )
3913                    .trace_err();
3914            }
3915        }
3916    }
3917    Ok(())
3918}
3919
/// Remove `connection_id` from the room it is in and notify everyone affected.
///
/// If the database reports that the connection was not in a room, this returns
/// `Ok(())` without doing anything else. Otherwise it:
/// - notifies collaborators of the projects left behind (via `project_left`);
/// - broadcasts the updated room and the channel's participants;
/// - sends `CallCanceled` to users whose pending calls were canceled;
/// - refreshes contact status for the session user and those callees;
/// - removes the session user from the LiveKit room, and deletes that room when
///   the database marked it as deleted. LiveKit failures are only logged.
///
/// NOTE(review): LiveKit removal and the contact refresh use `session.user_id()`,
/// not the owner of `connection_id`. The visible caller (`join_channel_internal`)
/// passes a stale connection of the same user; confirm this for other callers.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned inside the `if let` below and used after it; the `else` branch
    // returns early, so all of these are initialized past that point.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // `livekit_room` is taken out before `room` is, so the `room` broadcast
        // below carries an empty `livekit_room` value.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Pool guard scoped to this block so it is released before
    // `update_user_contacts` acquires the pool again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
3993
3994async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3995    let left_channel_buffers = session
3996        .db()
3997        .await
3998        .leave_channel_buffers(session.connection_id)
3999        .await?;
4000
4001    for left_buffer in left_channel_buffers {
4002        channel_buffer_updated(
4003            session.connection_id,
4004            left_buffer.connections,
4005            &proto::UpdateChannelBufferCollaborators {
4006                channel_id: left_buffer.channel_id.to_proto(),
4007                collaborators: left_buffer.collaborators,
4008            },
4009            &session.peer,
4010        );
4011    }
4012
4013    Ok(())
4014}
4015
4016fn project_left(project: &db::LeftProject, session: &Session) {
4017    for connection_id in &project.connection_ids {
4018        if project.should_unshare {
4019            session
4020                .peer
4021                .send(
4022                    *connection_id,
4023                    proto::UnshareProject {
4024                        project_id: project.id.to_proto(),
4025                    },
4026                )
4027                .trace_err();
4028        } else {
4029            session
4030                .peer
4031                .send(
4032                    *connection_id,
4033                    proto::RemoveProjectCollaborator {
4034                        project_id: project.id.to_proto(),
4035                        peer_id: Some(session.connection_id.into()),
4036                    },
4037                )
4038                .trace_err();
4039        }
4040    }
4041}
4042
4043async fn share_agent_thread(
4044    request: proto::ShareAgentThread,
4045    response: Response<proto::ShareAgentThread>,
4046    session: MessageContext,
4047) -> Result<()> {
4048    let user_id = session.user_id();
4049
4050    let share_id = SharedThreadId::from_proto(request.session_id.clone())
4051        .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4052
4053    session
4054        .db()
4055        .await
4056        .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
4057        .await?;
4058
4059    response.send(proto::Ack {})?;
4060
4061    Ok(())
4062}
4063
4064async fn get_shared_agent_thread(
4065    request: proto::GetSharedAgentThread,
4066    response: Response<proto::GetSharedAgentThread>,
4067    session: MessageContext,
4068) -> Result<()> {
4069    let share_id = SharedThreadId::from_proto(request.session_id)
4070        .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4071
4072    let result = session.db().await.get_shared_thread(share_id).await?;
4073
4074    match result {
4075        Some((thread, username)) => {
4076            response.send(proto::GetSharedAgentThreadResponse {
4077                title: thread.title,
4078                thread_data: thread.data,
4079                sharer_username: username,
4080                created_at: thread.created_at.and_utc().to_rfc3339(),
4081            })?;
4082        }
4083        None => {
4084            return Err(anyhow!("Shared thread not found").into());
4085        }
4086    }
4087
4088    Ok(())
4089}
4090
4091pub trait ResultExt {
4092    type Ok;
4093
4094    fn trace_err(self) -> Option<Self::Ok>;
4095}
4096
4097impl<T, E> ResultExt for Result<T, E>
4098where
4099    E: std::fmt::Debug,
4100{
4101    type Ok = T;
4102
4103    #[track_caller]
4104    fn trace_err(self) -> Option<T> {
4105        match self {
4106            Ok(value) => Some(value),
4107            Err(error) => {
4108                tracing::error!("{:?}", error);
4109                None
4110            }
4111        }
4112    }
4113}