rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::{
   5    AppState, Error, Result, auth,
   6    db::{
   7        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
   8        InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
   9        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
  10    },
  11    executor::Executor,
  12};
  13use anyhow::{Context as _, anyhow, bail};
  14use async_tungstenite::tungstenite::{
  15    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  16};
  17use axum::headers::UserAgent;
  18use axum::{
  19    Extension, Router, TypedHeader,
  20    body::Body,
  21    extract::{
  22        ConnectInfo, WebSocketUpgrade,
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30};
  31use collections::{HashMap, HashSet};
  32pub use connection_pool::{ConnectionPool, ZedVersion};
  33use core::fmt::{self, Debug, Formatter};
  34use futures::TryFutureExt as _;
  35use rpc::proto::split_repository_update;
  36use tracing::Span;
  37use util::paths::PathStyle;
  38
  39use futures::{
  40    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  41    stream::FuturesUnordered,
  42};
  43use prometheus::{IntGauge, register_int_gauge};
  44use rpc::{
  45    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  46    proto::{
  47        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  48        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  49    },
  50};
  51use semver::Version;
  52use serde::{Serialize, Serializer};
  53use std::{
  54    any::TypeId,
  55    future::Future,
  56    marker::PhantomData,
  57    mem,
  58    net::SocketAddr,
  59    ops::{Deref, DerefMut},
  60    rc::Rc,
  61    sync::{
  62        Arc, OnceLock,
  63        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  64    },
  65    time::{Duration, Instant},
  66};
  67use tokio::sync::{Semaphore, watch};
  68use tower::ServiceBuilder;
  69use tracing::{
  70    Instrument,
  71    field::{self},
  72    info_span, instrument,
  73};
  74
/// Reconnection timeout: 30 seconds.
///
/// NOTE(review): not referenced in this part of the file; presumably how long a
/// disconnected client is given to reconnect — confirm at the usage site.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay that `Server::start` waits before cleaning up resources left behind by stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
/// clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size for notifications.
///
/// NOTE(review): the usage is outside this part of the file — presumably the number of
/// notifications returned per request; confirm.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Maximum number of `ConnectionGuard`s that may be alive at the same time.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Number of live `ConnectionGuard`s: incremented by `ConnectionGuard::try_acquire` and
/// decremented when a guard is dropped.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Names of the tracing span fields used for message timings (values are recorded in
// milliseconds). The first three are recorded by `Server::add_handler`, the last one by
// `MessageContext::forward_request`.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";
  89
/// Type-erased message handler as stored in `Server`'s `handlers` map.
///
/// It receives the boxed envelope, the `Session` and the tracing `Span` for the message, and
/// returns a boxed future that completes once the message has been handled.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
  92
  93pub struct ConnectionGuard;
  94
  95impl ConnectionGuard {
  96    pub fn try_acquire() -> Result<Self, ()> {
  97        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
  98        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
  99            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 100            tracing::error!(
 101                "too many concurrent connections: {}",
 102                current_connections + 1
 103            );
 104            return Err(());
 105        }
 106        Ok(ConnectionGuard)
 107    }
 108}
 109
 110impl Drop for ConnectionGuard {
 111    fn drop(&mut self) {
 112        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 113    }
 114}
 115
 116struct Response<R> {
 117    peer: Arc<Peer>,
 118    receipt: Receipt<R>,
 119    responded: Arc<AtomicBool>,
 120}
 121
 122impl<R: RequestMessage> Response<R> {
 123    fn send(self, payload: R::Response) -> Result<()> {
 124        self.responded.store(true, SeqCst);
 125        self.peer.respond(self.receipt, payload)?;
 126        Ok(())
 127    }
 128}
 129
 130#[derive(Clone, Debug)]
 131pub enum Principal {
 132    User(User),
 133    Impersonated { user: User, admin: User },
 134}
 135
 136impl Principal {
 137    fn update_span(&self, span: &tracing::Span) {
 138        match &self {
 139            Principal::User(user) => {
 140                span.record("user_id", user.id.0);
 141                span.record("login", &user.github_login);
 142            }
 143            Principal::Impersonated { user, admin } => {
 144                span.record("user_id", user.id.0);
 145                span.record("login", &user.github_login);
 146                span.record("impersonator", &admin.github_login);
 147            }
 148        }
 149    }
 150}
 151
 152#[derive(Clone)]
 153struct MessageContext {
 154    session: Session,
 155    span: tracing::Span,
 156}
 157
 158impl Deref for MessageContext {
 159    type Target = Session;
 160
 161    fn deref(&self) -> &Self::Target {
 162        &self.session
 163    }
 164}
 165
 166impl MessageContext {
 167    pub fn forward_request<T: RequestMessage>(
 168        &self,
 169        receiver_id: ConnectionId,
 170        request: T,
 171    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 172        let request_start_time = Instant::now();
 173        let span = self.span.clone();
 174        tracing::info!("start forwarding request");
 175        self.peer
 176            .forward_request(self.connection_id, receiver_id, request)
 177            .inspect(move |_| {
 178                span.record(
 179                    HOST_WAITING_MS,
 180                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 181                );
 182            })
 183            .inspect_err(|_| tracing::error!("error forwarding request"))
 184            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 185    }
 186}
 187
 188#[derive(Clone)]
 189struct Session {
 190    principal: Principal,
 191    connection_id: ConnectionId,
 192    db: Arc<tokio::sync::Mutex<DbHandle>>,
 193    peer: Arc<Peer>,
 194    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 195    app_state: Arc<AppState>,
 196    /// The GeoIP country code for the user.
 197    #[allow(unused)]
 198    geoip_country_code: Option<String>,
 199    #[allow(unused)]
 200    system_id: Option<String>,
 201    _executor: Executor,
 202}
 203
 204impl Session {
 205    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 206        #[cfg(test)]
 207        tokio::task::yield_now().await;
 208        let guard = self.db.lock().await;
 209        #[cfg(test)]
 210        tokio::task::yield_now().await;
 211        guard
 212    }
 213
 214    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 215        #[cfg(test)]
 216        tokio::task::yield_now().await;
 217        let guard = self.connection_pool.lock();
 218        ConnectionPoolGuard {
 219            guard,
 220            _not_send: PhantomData,
 221        }
 222    }
 223
 224    #[expect(dead_code)]
 225    fn is_staff(&self) -> bool {
 226        match &self.principal {
 227            Principal::User(user) => user.admin,
 228            Principal::Impersonated { .. } => true,
 229        }
 230    }
 231
 232    fn user_id(&self) -> UserId {
 233        match &self.principal {
 234            Principal::User(user) => user.id,
 235            Principal::Impersonated { user, .. } => user.id,
 236        }
 237    }
 238}
 239
 240impl Debug for Session {
 241    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 242        let mut result = f.debug_struct("Session");
 243        match &self.principal {
 244            Principal::User(user) => {
 245                result.field("user", &user.github_login);
 246            }
 247            Principal::Impersonated { user, admin } => {
 248                result.field("user", &user.github_login);
 249                result.field("impersonator", &admin.github_login);
 250            }
 251        }
 252        result.field("connection_id", &self.connection_id).finish()
 253    }
 254}
 255
 256struct DbHandle(Arc<Database>);
 257
 258impl Deref for DbHandle {
 259    type Target = Database;
 260
 261    fn deref(&self) -> &Self::Target {
 262        self.0.as_ref()
 263    }
 264}
 265
/// The RPC server: owns the peer, the pool of live connections and the table of message
/// handlers.
pub struct Server {
    /// Current server id. Behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Pool of live connections, shared with the sessions and the cleanup task.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Message handlers, keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown signal: `teardown` sends `true`, the test-only `reset` sends `false` again.
    teardown: watch::Sender<bool>,
}
 274
/// A locked view of the connection pool.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying lock guard on the pool.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc` is not `Send`, so this marker keeps the whole guard `!Send`.
    // NOTE(review): presumably this is meant to stop the guard from being held across an
    // `.await` in a spawned task — confirm the intent.
    _not_send: PhantomData<Rc<()>>,
}
 279
/// Serializable view of the server: its peer and its (locked) connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized through `serialize_deref`, i.e. as the value the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 286
 287pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 288where
 289    S: Serializer,
 290    T: Deref<Target = U>,
 291    U: Serialize,
 292{
 293    Serialize::serialize(value.deref(), serializer)
 294}
 295
 296impl Server {
    /// Creates a server with the given id and registers every message and request handler.
    ///
    /// The peer is created with `id.0 as u32`; the connection pool starts from its default
    /// value and the teardown signal starts as `false`.
    ///
    /// # Panics
    ///
    /// `add_handler` panics when a second handler is registered for a message type, so
    /// listing the same message type twice below aborts construction.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(lsp_query)
            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            .add_message_handler(create_buffer_for_peer)
            .add_message_handler(create_image_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            // NOTE(review): `GitReset` and `GitCheckoutFiles` go through the read-only
            // forwarder although, judging by their names, they modify the repository.
            // Confirm whether they should use `forward_mutating_project_request` instead.
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>);

        Arc::new(server)
    }
 472
    /// Schedules the cleanup of resources left behind by stale servers and returns.
    ///
    /// The cleanup itself is handed to `executor.spawn_detached`; this method does not wait
    /// for it and always returns `Ok(())`. Every database and send error inside the task goes
    /// through `trace_err()` instead of aborting the cleanup.
    ///
    /// The task, after waiting for `CLEANUP_TIMEOUT`:
    /// 1. deletes stale channel chat participants,
    /// 2. clears stale channel buffer collaborators and sends the refreshed collaborator
    ///    lists to the connections returned for each channel buffer,
    /// 3. clears stale room participants, sends room/channel updates, cancels pending calls,
    ///    sends contact updates and deletes LiveKit rooms that ended up empty,
    /// 4. deletes stale channel chat participants again, clears old worktree entries and
    ///    deletes stale servers.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is created here and awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and send the refreshed list
                    // to every connection returned for that buffer.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Collected from the refreshed room below; they keep these defaults
                        // (nothing to notify, nothing to delete) if the refresh fails.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // Only a room left without participants gets its LiveKit room deleted.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the next `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Send each affected user's updated contact entry to all connections
                        // of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                // Locked only after both queries have completed.
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                // NOTE(review): same call as at the top of the task — presumably repeated on
                // purpose to catch participants that became stale during cleanup; confirm.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 643
 644    pub fn teardown(&self) {
 645        self.peer.teardown();
 646        self.connection_pool.lock().reset();
 647        let _ = self.teardown.send(true);
 648    }
 649
 650    #[cfg(test)]
 651    pub fn reset(&self, id: ServerId) {
 652        self.teardown();
 653        *self.id.lock() = id;
 654        self.peer.reset(id.0 as u32);
 655        let _ = self.teardown.send(false);
 656    }
 657
 658    #[cfg(test)]
 659    pub fn id(&self) -> ServerId {
 660        *self.id.lock()
 661    }
 662
 663    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 664    where
 665        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
 666        Fut: 'static + Send + Future<Output = Result<()>>,
 667        M: EnvelopedMessage,
 668    {
 669        let prev_handler = self.handlers.insert(
 670            TypeId::of::<M>(),
 671            Box::new(move |envelope, session, span| {
 672                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 673                let received_at = envelope.received_at;
 674                tracing::info!("message received");
 675                let start_time = Instant::now();
 676                let future = (handler)(
 677                    *envelope,
 678                    MessageContext {
 679                        session,
 680                        span: span.clone(),
 681                    },
 682                );
 683                async move {
 684                    let result = future.await;
 685                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 686                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 687                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 688                    span.record(TOTAL_DURATION_MS, total_duration_ms);
 689                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
 690                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
 691                    match result {
 692                        Err(error) => {
 693                            tracing::error!(?error, "error handling message")
 694                        }
 695                        Ok(()) => tracing::info!("finished handling message"),
 696                    }
 697                }
 698                .boxed()
 699            }),
 700        );
 701        if prev_handler.is_some() {
 702            panic!("registered a handler for the same message twice");
 703        }
 704        self
 705    }
 706
 707    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 708    where
 709        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 710        Fut: 'static + Send + Future<Output = Result<()>>,
 711        M: EnvelopedMessage,
 712    {
 713        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 714        self
 715    }
 716
 717    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 718    where
 719        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 720        Fut: Send + Future<Output = Result<()>>,
 721        M: RequestMessage,
 722    {
 723        let handler = Arc::new(handler);
 724        self.add_handler(move |envelope, session| {
 725            let receipt = envelope.receipt();
 726            let handler = handler.clone();
 727            async move {
 728                let peer = session.peer.clone();
 729                let responded = Arc::new(AtomicBool::default());
 730                let response = Response {
 731                    peer: peer.clone(),
 732                    responded: responded.clone(),
 733                    receipt,
 734                };
 735                match (handler)(envelope.payload, response, session).await {
 736                    Ok(()) => {
 737                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 738                            Ok(())
 739                        } else {
 740                            Err(anyhow!("handler did not send a response"))?
 741                        }
 742                    }
 743                    Err(error) => {
 744                        let proto_err = match &error {
 745                            Error::Internal(err) => err.to_proto(),
 746                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 747                        };
 748                        peer.respond_with_error(receipt, proto_err)?;
 749                        Err(error)
 750                    }
 751                }
 752            }
 753        })
 754    }
 755
 756    pub fn handle_connection(
 757        self: &Arc<Self>,
 758        connection: Connection,
 759        address: String,
 760        principal: Principal,
 761        zed_version: ZedVersion,
 762        release_channel: Option<String>,
 763        user_agent: Option<String>,
 764        geoip_country_code: Option<String>,
 765        system_id: Option<String>,
 766        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 767        executor: Executor,
 768        connection_guard: Option<ConnectionGuard>,
 769    ) -> impl Future<Output = ()> + use<> {
 770        let this = self.clone();
 771        let span = info_span!("handle connection", %address,
 772            connection_id=field::Empty,
 773            user_id=field::Empty,
 774            login=field::Empty,
 775            impersonator=field::Empty,
 776            user_agent=field::Empty,
 777            geoip_country_code=field::Empty,
 778            release_channel=field::Empty,
 779        );
 780        principal.update_span(&span);
 781        if let Some(user_agent) = user_agent {
 782            span.record("user_agent", user_agent);
 783        }
 784        if let Some(release_channel) = release_channel {
 785            span.record("release_channel", release_channel);
 786        }
 787
 788        if let Some(country_code) = geoip_country_code.as_ref() {
 789            span.record("geoip_country_code", country_code);
 790        }
 791
 792        let mut teardown = self.teardown.subscribe();
 793        async move {
 794            if *teardown.borrow() {
 795                tracing::error!("server is tearing down");
 796                return;
 797            }
 798
 799            let (connection_id, handle_io, mut incoming_rx) =
 800                this.peer.add_connection(connection, {
 801                    let executor = executor.clone();
 802                    move |duration| executor.sleep(duration)
 803                });
 804            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 805
 806            tracing::info!("connection opened");
 807
 808            let session = Session {
 809                principal: principal.clone(),
 810                connection_id,
 811                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 812                peer: this.peer.clone(),
 813                connection_pool: this.connection_pool.clone(),
 814                app_state: this.app_state.clone(),
 815                geoip_country_code,
 816                system_id,
 817                _executor: executor.clone(),
 818            };
 819
 820            if let Err(error) = this
 821                .send_initial_client_update(
 822                    connection_id,
 823                    zed_version,
 824                    send_connection_id,
 825                    &session,
 826                )
 827                .await
 828            {
 829                tracing::error!(?error, "failed to send initial client update");
 830                return;
 831            }
 832            drop(connection_guard);
 833
 834            let handle_io = handle_io.fuse();
 835            futures::pin_mut!(handle_io);
 836
 837            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 838            // This prevents deadlocks when e.g., client A performs a request to client B and
 839            // client B performs a request to client A. If both clients stop processing further
 840            // messages until their respective request completes, they won't have a chance to
 841            // respond to the other client's request and cause a deadlock.
 842            //
 843            // This arrangement ensures we will attempt to process earlier messages first, but fall
 844            // back to processing messages arrived later in the spirit of making progress.
 845            const MAX_CONCURRENT_HANDLERS: usize = 256;
 846            let mut foreground_message_handlers = FuturesUnordered::new();
 847            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
 848            let get_concurrent_handlers = {
 849                let concurrent_handlers = concurrent_handlers.clone();
 850                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
 851            };
 852            loop {
 853                let next_message = async {
 854                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 855                    let message = incoming_rx.next().await;
 856                    // Cache the concurrent_handlers here, so that we know what the
 857                    // queue looks like as each handler starts
 858                    (permit, message, get_concurrent_handlers())
 859                }
 860                .fuse();
 861                futures::pin_mut!(next_message);
 862                futures::select_biased! {
 863                    _ = teardown.changed().fuse() => return,
 864                    result = handle_io => {
 865                        if let Err(error) = result {
 866                            tracing::error!(?error, "error handling I/O");
 867                        }
 868                        break;
 869                    }
 870                    _ = foreground_message_handlers.next() => {}
 871                    next_message = next_message => {
 872                        let (permit, message, concurrent_handlers) = next_message;
 873                        if let Some(message) = message {
 874                            let type_name = message.payload_type_name();
 875                            // note: we copy all the fields from the parent span so we can query them in the logs.
 876                            // (https://github.com/tokio-rs/tracing/issues/2670).
 877                            let span = tracing::info_span!("receive message",
 878                                %connection_id,
 879                                %address,
 880                                type_name,
 881                                concurrent_handlers,
 882                                user_id=field::Empty,
 883                                login=field::Empty,
 884                                impersonator=field::Empty,
 885                                lsp_query_request=field::Empty,
 886                                release_channel=field::Empty,
 887                                { TOTAL_DURATION_MS }=field::Empty,
 888                                { PROCESSING_DURATION_MS }=field::Empty,
 889                                { QUEUE_DURATION_MS }=field::Empty,
 890                                { HOST_WAITING_MS }=field::Empty
 891                            );
 892                            principal.update_span(&span);
 893                            let span_enter = span.enter();
 894                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 895                                let is_background = message.is_background();
 896                                let handle_message = (handler)(message, session.clone(), span.clone());
 897                                drop(span_enter);
 898
 899                                let handle_message = async move {
 900                                    handle_message.await;
 901                                    drop(permit);
 902                                }.instrument(span);
 903                                if is_background {
 904                                    executor.spawn_detached(handle_message);
 905                                } else {
 906                                    foreground_message_handlers.push(handle_message);
 907                                }
 908                            } else {
 909                                tracing::error!("no message handler");
 910                            }
 911                        } else {
 912                            tracing::info!("connection closed");
 913                            break;
 914                        }
 915                    }
 916                }
 917            }
 918
 919            drop(foreground_message_handlers);
 920            let concurrent_handlers = get_concurrent_handlers();
 921            tracing::info!(concurrent_handlers, "signing out");
 922            if let Err(error) = connection_lost(session, teardown, executor).await {
 923                tracing::error!(?error, "error signing out");
 924            }
 925        }
 926        .instrument(span)
 927    }
 928
 929    async fn send_initial_client_update(
 930        &self,
 931        connection_id: ConnectionId,
 932        zed_version: ZedVersion,
 933        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 934        session: &Session,
 935    ) -> Result<()> {
 936        self.peer.send(
 937            connection_id,
 938            proto::Hello {
 939                peer_id: Some(connection_id.into()),
 940            },
 941        )?;
 942        tracing::info!("sent hello message");
 943        if let Some(send_connection_id) = send_connection_id.take() {
 944            let _ = send_connection_id.send(connection_id);
 945        }
 946
 947        match &session.principal {
 948            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 949                if !user.connected_once {
 950                    self.peer.send(connection_id, proto::ShowContacts {})?;
 951                    self.app_state
 952                        .db
 953                        .set_user_connected_once(user.id, true)
 954                        .await?;
 955                }
 956
 957                let contacts = self.app_state.db.get_contacts(user.id).await?;
 958
 959                {
 960                    let mut pool = self.connection_pool.lock();
 961                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
 962                    self.peer.send(
 963                        connection_id,
 964                        build_initial_contacts_update(contacts, &pool),
 965                    )?;
 966                }
 967
 968                if should_auto_subscribe_to_channels(&zed_version) {
 969                    subscribe_user_to_channels(user.id, session).await?;
 970                }
 971
 972                if let Some(incoming_call) =
 973                    self.app_state.db.incoming_call_for_user(user.id).await?
 974                {
 975                    self.peer.send(connection_id, incoming_call)?;
 976                }
 977
 978                update_user_contacts(user.id, session).await?;
 979            }
 980        }
 981
 982        Ok(())
 983    }
 984
 985    pub async fn invite_code_redeemed(
 986        self: &Arc<Self>,
 987        inviter_id: UserId,
 988        invitee_id: UserId,
 989    ) -> Result<()> {
 990        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await?
 991            && let Some(code) = &user.invite_code
 992        {
 993            let pool = self.connection_pool.lock();
 994            let invitee_contact = contact_for_user(invitee_id, false, &pool);
 995            for connection_id in pool.user_connection_ids(inviter_id) {
 996                self.peer.send(
 997                    connection_id,
 998                    proto::UpdateContacts {
 999                        contacts: vec![invitee_contact.clone()],
1000                        ..Default::default()
1001                    },
1002                )?;
1003                self.peer.send(
1004                    connection_id,
1005                    proto::UpdateInviteInfo {
1006                        url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1007                        count: user.invite_count as u32,
1008                    },
1009                )?;
1010            }
1011        }
1012        Ok(())
1013    }
1014
1015    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1016        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await?
1017            && let Some(invite_code) = &user.invite_code
1018        {
1019            let pool = self.connection_pool.lock();
1020            for connection_id in pool.user_connection_ids(user_id) {
1021                self.peer.send(
1022                    connection_id,
1023                    proto::UpdateInviteInfo {
1024                        url: format!(
1025                            "{}{}",
1026                            self.app_state.config.invite_link_prefix, invite_code
1027                        ),
1028                        count: user.invite_count as u32,
1029                    },
1030                )?;
1031            }
1032        }
1033        Ok(())
1034    }
1035
1036    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1037        ServerSnapshot {
1038            connection_pool: ConnectionPoolGuard {
1039                guard: self.connection_pool.lock(),
1040                _not_send: PhantomData,
1041            },
1042            peer: &self.peer,
1043        }
1044    }
1045}
1046
1047impl Deref for ConnectionPoolGuard<'_> {
1048    type Target = ConnectionPool;
1049
1050    fn deref(&self) -> &Self::Target {
1051        &self.guard
1052    }
1053}
1054
1055impl DerefMut for ConnectionPoolGuard<'_> {
1056    fn deref_mut(&mut self) -> &mut Self::Target {
1057        &mut self.guard
1058    }
1059}
1060
1061impl Drop for ConnectionPoolGuard<'_> {
1062    fn drop(&mut self) {
1063        #[cfg(test)]
1064        self.check_invariants();
1065    }
1066}
1067
1068fn broadcast<F>(
1069    sender_id: Option<ConnectionId>,
1070    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1071    mut f: F,
1072) where
1073    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1074{
1075    for receiver_id in receiver_ids {
1076        if Some(receiver_id) != sender_id
1077            && let Err(error) = f(receiver_id)
1078        {
1079            tracing::error!("failed to send to {:?} {}", receiver_id, error);
1080        }
1081    }
1082}
1083
1084pub struct ProtocolVersion(u32);
1085
1086impl Header for ProtocolVersion {
1087    fn name() -> &'static HeaderName {
1088        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1089        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1090    }
1091
1092    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1093    where
1094        Self: Sized,
1095        I: Iterator<Item = &'i axum::http::HeaderValue>,
1096    {
1097        let version = values
1098            .next()
1099            .ok_or_else(axum::headers::Error::invalid)?
1100            .to_str()
1101            .map_err(|_| axum::headers::Error::invalid())?
1102            .parse()
1103            .map_err(|_| axum::headers::Error::invalid())?;
1104        Ok(Self(version))
1105    }
1106
1107    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1108        values.extend([self.0.to_string().parse().unwrap()]);
1109    }
1110}
1111
1112pub struct AppVersionHeader(Version);
1113impl Header for AppVersionHeader {
1114    fn name() -> &'static HeaderName {
1115        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1116        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1117    }
1118
1119    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1120    where
1121        Self: Sized,
1122        I: Iterator<Item = &'i axum::http::HeaderValue>,
1123    {
1124        let version = values
1125            .next()
1126            .ok_or_else(axum::headers::Error::invalid)?
1127            .to_str()
1128            .map_err(|_| axum::headers::Error::invalid())?
1129            .parse()
1130            .map_err(|_| axum::headers::Error::invalid())?;
1131        Ok(Self(version))
1132    }
1133
1134    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1135        values.extend([self.0.to_string().parse().unwrap()]);
1136    }
1137}
1138
1139#[derive(Debug)]
1140pub struct ReleaseChannelHeader(String);
1141
1142impl Header for ReleaseChannelHeader {
1143    fn name() -> &'static HeaderName {
1144        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1145        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1146    }
1147
1148    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1149    where
1150        Self: Sized,
1151        I: Iterator<Item = &'i axum::http::HeaderValue>,
1152    {
1153        Ok(Self(
1154            values
1155                .next()
1156                .ok_or_else(axum::headers::Error::invalid)?
1157                .to_str()
1158                .map_err(|_| axum::headers::Error::invalid())?
1159                .to_owned(),
1160        ))
1161    }
1162
1163    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1164        values.extend([self.0.parse().unwrap()]);
1165    }
1166}
1167
1168pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1169    Router::new()
1170        .route("/rpc", get(handle_websocket_request))
1171        .layer(
1172            ServiceBuilder::new()
1173                .layer(Extension(server.app_state.clone()))
1174                .layer(middleware::from_fn(auth::validate_header)),
1175        )
1176        .route("/metrics", get(handle_metrics))
1177        .layer(Extension(server))
1178}
1179
1180pub async fn handle_websocket_request(
1181    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1182    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1183    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1184    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1185    Extension(server): Extension<Arc<Server>>,
1186    Extension(principal): Extension<Principal>,
1187    user_agent: Option<TypedHeader<UserAgent>>,
1188    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1189    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1190    ws: WebSocketUpgrade,
1191) -> axum::response::Response {
1192    if protocol_version != rpc::PROTOCOL_VERSION {
1193        return (
1194            StatusCode::UPGRADE_REQUIRED,
1195            "client must be upgraded".to_string(),
1196        )
1197            .into_response();
1198    }
1199
1200    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1201        return (
1202            StatusCode::UPGRADE_REQUIRED,
1203            "no version header found".to_string(),
1204        )
1205            .into_response();
1206    };
1207
1208    let release_channel = release_channel_header.map(|header| header.0.0);
1209
1210    if !version.can_collaborate() {
1211        return (
1212            StatusCode::UPGRADE_REQUIRED,
1213            "client must be upgraded".to_string(),
1214        )
1215            .into_response();
1216    }
1217
1218    let socket_address = socket_address.to_string();
1219
1220    // Acquire connection guard before WebSocket upgrade
1221    let connection_guard = match ConnectionGuard::try_acquire() {
1222        Ok(guard) => guard,
1223        Err(()) => {
1224            return (
1225                StatusCode::SERVICE_UNAVAILABLE,
1226                "Too many concurrent connections",
1227            )
1228                .into_response();
1229        }
1230    };
1231
1232    ws.on_upgrade(move |socket| {
1233        let socket = socket
1234            .map_ok(to_tungstenite_message)
1235            .err_into()
1236            .with(|message| async move { to_axum_message(message) });
1237        let connection = Connection::new(Box::pin(socket));
1238        async move {
1239            server
1240                .handle_connection(
1241                    connection,
1242                    socket_address,
1243                    principal,
1244                    version,
1245                    release_channel,
1246                    user_agent.map(|header| header.to_string()),
1247                    country_code_header.map(|header| header.to_string()),
1248                    system_id_header.map(|header| header.to_string()),
1249                    None,
1250                    Executor::Production,
1251                    Some(connection_guard),
1252                )
1253                .await;
1254        }
1255    })
1256}
1257
1258pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1259    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1260    let connections_metric = CONNECTIONS_METRIC
1261        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1262
1263    let connections = server
1264        .connection_pool
1265        .lock()
1266        .connections()
1267        .filter(|connection| !connection.admin)
1268        .count();
1269    connections_metric.set(connections as _);
1270
1271    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1272    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1273        register_int_gauge!(
1274            "shared_projects",
1275            "number of open projects with one or more guests"
1276        )
1277        .unwrap()
1278    });
1279
1280    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1281    shared_projects_metric.set(shared_projects as _);
1282
1283    let encoder = prometheus::TextEncoder::new();
1284    let metric_families = prometheus::gather();
1285    let encoded_metrics = encoder
1286        .encode_to_string(&metric_families)
1287        .map_err(|err| anyhow!("{err}"))?;
1288    Ok(encoded_metrics)
1289}
1290
1291#[instrument(err, skip(executor))]
1292async fn connection_lost(
1293    session: Session,
1294    mut teardown: watch::Receiver<bool>,
1295    executor: Executor,
1296) -> Result<()> {
1297    session.peer.disconnect(session.connection_id);
1298    session
1299        .connection_pool()
1300        .await
1301        .remove_connection(session.connection_id)?;
1302
1303    session
1304        .db()
1305        .await
1306        .connection_lost(session.connection_id)
1307        .await
1308        .trace_err();
1309
1310    futures::select_biased! {
1311        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1312
1313            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1314            leave_room_for_session(&session, session.connection_id).await.trace_err();
1315            leave_channel_buffers_for_session(&session)
1316                .await
1317                .trace_err();
1318
1319            if !session
1320                .connection_pool()
1321                .await
1322                .is_user_online(session.user_id())
1323            {
1324                let db = session.db().await;
1325                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1326                    room_updated(&room, &session.peer);
1327                }
1328            }
1329
1330            update_user_contacts(session.user_id(), &session).await?;
1331        },
1332        _ = teardown.changed().fuse() => {}
1333    }
1334
1335    Ok(())
1336}
1337
1338/// Acknowledges a ping from a client, used to keep the connection alive.
1339async fn ping(
1340    _: proto::Ping,
1341    response: Response<proto::Ping>,
1342    _session: MessageContext,
1343) -> Result<()> {
1344    response.send(proto::Ack {})?;
1345    Ok(())
1346}
1347
1348/// Creates a new room for calling (outside of channels)
1349async fn create_room(
1350    _request: proto::CreateRoom,
1351    response: Response<proto::CreateRoom>,
1352    session: MessageContext,
1353) -> Result<()> {
1354    let livekit_room = nanoid::nanoid!(30);
1355
1356    let live_kit_connection_info = util::maybe!(async {
1357        let live_kit = session.app_state.livekit_client.as_ref();
1358        let live_kit = live_kit?;
1359        let user_id = session.user_id().to_string();
1360
1361        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1362
1363        Some(proto::LiveKitConnectionInfo {
1364            server_url: live_kit.url().into(),
1365            token,
1366            can_publish: true,
1367        })
1368    })
1369    .await;
1370
1371    let room = session
1372        .db()
1373        .await
1374        .create_room(session.user_id(), session.connection_id, &livekit_room)
1375        .await?;
1376
1377    response.send(proto::CreateRoomResponse {
1378        room: Some(room.clone()),
1379        live_kit_connection_info,
1380    })?;
1381
1382    update_user_contacts(session.user_id(), &session).await?;
1383    Ok(())
1384}
1385
1386/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1387async fn join_room(
1388    request: proto::JoinRoom,
1389    response: Response<proto::JoinRoom>,
1390    session: MessageContext,
1391) -> Result<()> {
1392    let room_id = RoomId::from_proto(request.id);
1393
1394    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1395
1396    if let Some(channel_id) = channel_id {
1397        return join_channel_internal(channel_id, Box::new(response), session).await;
1398    }
1399
1400    let joined_room = {
1401        let room = session
1402            .db()
1403            .await
1404            .join_room(room_id, session.user_id(), session.connection_id)
1405            .await?;
1406        room_updated(&room.room, &session.peer);
1407        room.into_inner()
1408    };
1409
1410    for connection_id in session
1411        .connection_pool()
1412        .await
1413        .user_connection_ids(session.user_id())
1414    {
1415        session
1416            .peer
1417            .send(
1418                connection_id,
1419                proto::CallCanceled {
1420                    room_id: room_id.to_proto(),
1421                },
1422            )
1423            .trace_err();
1424    }
1425
1426    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1427    {
1428        live_kit
1429            .room_token(
1430                &joined_room.room.livekit_room,
1431                &session.user_id().to_string(),
1432            )
1433            .trace_err()
1434            .map(|token| proto::LiveKitConnectionInfo {
1435                server_url: live_kit.url().into(),
1436                token,
1437                can_publish: true,
1438            })
1439    } else {
1440        None
1441    };
1442
1443    response.send(proto::JoinRoomResponse {
1444        room: Some(joined_room.room),
1445        channel_id: None,
1446        live_kit_connection_info,
1447    })?;
1448
1449    update_user_contacts(session.user_id(), &session).await?;
1450    Ok(())
1451}
1452
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Asks the database to rejoin the room for this user and connection, replies
/// with the room plus its reshared and rejoined projects, and then fans the
/// change out: `room_updated` for the room, `UpdateProjectCollaborator` and
/// `UpdateProject` messages for every reshared project,
/// `notify_rejoined_projects` for the rejoined ones, `channel_updated` when the
/// database returned a channel, and finally `update_user_contacts`.
///
/// # Errors
///
/// Returns an error if the database call, sending the response,
/// `notify_rejoined_projects` or `update_user_contacts` fails. Failures while
/// sending `UpdateProjectCollaborator` to an individual collaborator are handed
/// to `trace_err` and do not abort the handler.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    // Declared before the block so they outlive it; they are assigned from the
    // value produced by `into_inner()` at the end of the block.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Answer the requester first, before notifying anyone else.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that the peer previously known by
            // `old_connection_id` is now reachable under the current connection.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktree list to its collaborators on
            // behalf of the current connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Consume the value returned by the database and keep only what is
        // needed after this block.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1544
/// Brings a reconnecting session and its collaborators up to date for every
/// project in `rejoined_projects`.
///
/// First, each collaborator of each project is sent an
/// `UpdateProjectCollaborator` mapping the project's `old_connection_id` to the
/// session's current connection. Then, for each project, the session's own
/// connection is sent the project's worktree updates, diagnostic summaries,
/// worktree settings, repository updates and repository removals.
///
/// The worktrees, updated repositories and removed repositories are moved out
/// of each project with `mem::take`, so those fields are left empty afterwards.
///
/// # Errors
///
/// Returns the first error from a send to the session's own connection. Sends
/// to collaborators are handed to `trace_err` instead of being propagated.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    // Pass 1: announce the peer-id change to every collaborator.
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    // Pass 2: stream each project's state back to the rejoining connection.
    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            // The update is sent as the sequence of messages produced by
            // `split_worktree_update` rather than as one message.
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            // The first summary goes in `summary`, the rest in `more_summaries`;
            // nothing is sent when there are no summaries at all.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            // One settings message per settings file of this worktree.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        // Repository updates are likewise sent as the pieces produced by
        // `split_repository_update`.
        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1629
1630/// leave room disconnects from the room.
1631async fn leave_room(
1632    _: proto::LeaveRoom,
1633    response: Response<proto::LeaveRoom>,
1634    session: MessageContext,
1635) -> Result<()> {
1636    leave_room_for_session(&session, session.connection_id).await?;
1637    response.send(proto::Ack {})?;
1638    Ok(())
1639}
1640
1641/// Updates the permissions of someone else in the room.
1642async fn set_room_participant_role(
1643    request: proto::SetRoomParticipantRole,
1644    response: Response<proto::SetRoomParticipantRole>,
1645    session: MessageContext,
1646) -> Result<()> {
1647    let user_id = UserId::from_proto(request.user_id);
1648    let role = ChannelRole::from(request.role());
1649
1650    let (livekit_room, can_publish) = {
1651        let room = session
1652            .db()
1653            .await
1654            .set_room_participant_role(
1655                session.user_id(),
1656                RoomId::from_proto(request.room_id),
1657                user_id,
1658                role,
1659            )
1660            .await?;
1661
1662        let livekit_room = room.livekit_room.clone();
1663        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1664        room_updated(&room, &session.peer);
1665        (livekit_room, can_publish)
1666    };
1667
1668    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1669        live_kit
1670            .update_participant(
1671                livekit_room.clone(),
1672                request.user_id.to_string(),
1673                livekit_api::proto::ParticipantPermission {
1674                    can_subscribe: true,
1675                    can_publish,
1676                    can_publish_data: can_publish,
1677                    hidden: false,
1678                    recorder: false,
1679                },
1680            )
1681            .await
1682            .trace_err();
1683    }
1684
1685    response.send(proto::Ack {})?;
1686    Ok(())
1687}
1688
1689/// Call someone else into the current room
1690async fn call(
1691    request: proto::Call,
1692    response: Response<proto::Call>,
1693    session: MessageContext,
1694) -> Result<()> {
1695    let room_id = RoomId::from_proto(request.room_id);
1696    let calling_user_id = session.user_id();
1697    let calling_connection_id = session.connection_id;
1698    let called_user_id = UserId::from_proto(request.called_user_id);
1699    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1700    if !session
1701        .db()
1702        .await
1703        .has_contact(calling_user_id, called_user_id)
1704        .await?
1705    {
1706        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1707    }
1708
1709    let incoming_call = {
1710        let (room, incoming_call) = &mut *session
1711            .db()
1712            .await
1713            .call(
1714                room_id,
1715                calling_user_id,
1716                calling_connection_id,
1717                called_user_id,
1718                initial_project_id,
1719            )
1720            .await?;
1721        room_updated(room, &session.peer);
1722        mem::take(incoming_call)
1723    };
1724    update_user_contacts(called_user_id, &session).await?;
1725
1726    let mut calls = session
1727        .connection_pool()
1728        .await
1729        .user_connection_ids(called_user_id)
1730        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1731        .collect::<FuturesUnordered<_>>();
1732
1733    while let Some(call_response) = calls.next().await {
1734        match call_response.as_ref() {
1735            Ok(_) => {
1736                response.send(proto::Ack {})?;
1737                return Ok(());
1738            }
1739            Err(_) => {
1740                call_response.trace_err();
1741            }
1742        }
1743    }
1744
1745    {
1746        let room = session
1747            .db()
1748            .await
1749            .call_failed(room_id, called_user_id)
1750            .await?;
1751        room_updated(&room, &session.peer);
1752    }
1753    update_user_contacts(called_user_id, &session).await?;
1754
1755    Err(anyhow!("failed to ring user"))?
1756}
1757
1758/// Cancel an outgoing call.
1759async fn cancel_call(
1760    request: proto::CancelCall,
1761    response: Response<proto::CancelCall>,
1762    session: MessageContext,
1763) -> Result<()> {
1764    let called_user_id = UserId::from_proto(request.called_user_id);
1765    let room_id = RoomId::from_proto(request.room_id);
1766    {
1767        let room = session
1768            .db()
1769            .await
1770            .cancel_call(room_id, session.connection_id, called_user_id)
1771            .await?;
1772        room_updated(&room, &session.peer);
1773    }
1774
1775    for connection_id in session
1776        .connection_pool()
1777        .await
1778        .user_connection_ids(called_user_id)
1779    {
1780        session
1781            .peer
1782            .send(
1783                connection_id,
1784                proto::CallCanceled {
1785                    room_id: room_id.to_proto(),
1786                },
1787            )
1788            .trace_err();
1789    }
1790    response.send(proto::Ack {})?;
1791
1792    update_user_contacts(called_user_id, &session).await?;
1793    Ok(())
1794}
1795
1796/// Decline an incoming call.
1797async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1798    let room_id = RoomId::from_proto(message.room_id);
1799    {
1800        let room = session
1801            .db()
1802            .await
1803            .decline_call(Some(room_id), session.user_id())
1804            .await?
1805            .context("declining call")?;
1806        room_updated(&room, &session.peer);
1807    }
1808
1809    for connection_id in session
1810        .connection_pool()
1811        .await
1812        .user_connection_ids(session.user_id())
1813    {
1814        session
1815            .peer
1816            .send(
1817                connection_id,
1818                proto::CallCanceled {
1819                    room_id: room_id.to_proto(),
1820                },
1821            )
1822            .trace_err();
1823    }
1824    update_user_contacts(session.user_id(), &session).await?;
1825    Ok(())
1826}
1827
1828/// Updates other participants in the room with your current location.
1829async fn update_participant_location(
1830    request: proto::UpdateParticipantLocation,
1831    response: Response<proto::UpdateParticipantLocation>,
1832    session: MessageContext,
1833) -> Result<()> {
1834    let room_id = RoomId::from_proto(request.room_id);
1835    let location = request.location.context("invalid location")?;
1836
1837    let db = session.db().await;
1838    let room = db
1839        .update_room_participant_location(room_id, session.connection_id, location)
1840        .await?;
1841
1842    room_updated(&room, &session.peer);
1843    response.send(proto::Ack {})?;
1844    Ok(())
1845}
1846
1847/// Share a project into the room.
1848async fn share_project(
1849    request: proto::ShareProject,
1850    response: Response<proto::ShareProject>,
1851    session: MessageContext,
1852) -> Result<()> {
1853    let (project_id, room) = &*session
1854        .db()
1855        .await
1856        .share_project(
1857            RoomId::from_proto(request.room_id),
1858            session.connection_id,
1859            &request.worktrees,
1860            request.is_ssh_project,
1861            request.windows_paths.unwrap_or(false),
1862        )
1863        .await?;
1864    response.send(proto::ShareProjectResponse {
1865        project_id: project_id.to_proto(),
1866    })?;
1867    room_updated(room, &session.peer);
1868
1869    Ok(())
1870}
1871
1872/// Unshare a project from the room.
1873async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1874    let project_id = ProjectId::from_proto(message.project_id);
1875    unshare_project_internal(project_id, session.connection_id, &session).await
1876}
1877
1878async fn unshare_project_internal(
1879    project_id: ProjectId,
1880    connection_id: ConnectionId,
1881    session: &Session,
1882) -> Result<()> {
1883    let delete = {
1884        let room_guard = session
1885            .db()
1886            .await
1887            .unshare_project(project_id, connection_id)
1888            .await?;
1889
1890        let (delete, room, guest_connection_ids) = &*room_guard;
1891
1892        let message = proto::UnshareProject {
1893            project_id: project_id.to_proto(),
1894        };
1895
1896        broadcast(
1897            Some(connection_id),
1898            guest_connection_ids.iter().copied(),
1899            |conn_id| session.peer.send(conn_id, message.clone()),
1900        );
1901        if let Some(room) = room {
1902            room_updated(room, &session.peer);
1903        }
1904
1905        *delete
1906    };
1907
1908    if delete {
1909        let db = session.db().await;
1910        db.delete_project(project_id).await?;
1911    }
1912
1913    Ok(())
1914}
1915
1916/// Join someone elses shared project.
1917async fn join_project(
1918    request: proto::JoinProject,
1919    response: Response<proto::JoinProject>,
1920    session: MessageContext,
1921) -> Result<()> {
1922    let project_id = ProjectId::from_proto(request.project_id);
1923
1924    tracing::info!(%project_id, "join project");
1925
1926    let db = session.db().await;
1927    let (project, replica_id) = &mut *db
1928        .join_project(
1929            project_id,
1930            session.connection_id,
1931            session.user_id(),
1932            request.committer_name.clone(),
1933            request.committer_email.clone(),
1934        )
1935        .await?;
1936    drop(db);
1937    tracing::info!(%project_id, "join remote project");
1938    let collaborators = project
1939        .collaborators
1940        .iter()
1941        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1942        .map(|collaborator| collaborator.to_proto())
1943        .collect::<Vec<_>>();
1944    let project_id = project.id;
1945    let guest_user_id = session.user_id();
1946
1947    let worktrees = project
1948        .worktrees
1949        .iter()
1950        .map(|(id, worktree)| proto::WorktreeMetadata {
1951            id: *id,
1952            root_name: worktree.root_name.clone(),
1953            visible: worktree.visible,
1954            abs_path: worktree.abs_path.clone(),
1955        })
1956        .collect::<Vec<_>>();
1957
1958    let add_project_collaborator = proto::AddProjectCollaborator {
1959        project_id: project_id.to_proto(),
1960        collaborator: Some(proto::Collaborator {
1961            peer_id: Some(session.connection_id.into()),
1962            replica_id: replica_id.0 as u32,
1963            user_id: guest_user_id.to_proto(),
1964            is_host: false,
1965            committer_name: request.committer_name.clone(),
1966            committer_email: request.committer_email.clone(),
1967        }),
1968    };
1969
1970    for collaborator in &collaborators {
1971        session
1972            .peer
1973            .send(
1974                collaborator.peer_id.unwrap().into(),
1975                add_project_collaborator.clone(),
1976            )
1977            .trace_err();
1978    }
1979
1980    // First, we send the metadata associated with each worktree.
1981    let (language_servers, language_server_capabilities) = project
1982        .language_servers
1983        .clone()
1984        .into_iter()
1985        .map(|server| (server.server, server.capabilities))
1986        .unzip();
1987    response.send(proto::JoinProjectResponse {
1988        project_id: project.id.0 as u64,
1989        worktrees,
1990        replica_id: replica_id.0 as u32,
1991        collaborators,
1992        language_servers,
1993        language_server_capabilities,
1994        role: project.role.into(),
1995        windows_paths: project.path_style == PathStyle::Windows,
1996    })?;
1997
1998    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1999        // Stream this worktree's entries.
2000        let message = proto::UpdateWorktree {
2001            project_id: project_id.to_proto(),
2002            worktree_id,
2003            abs_path: worktree.abs_path.clone(),
2004            root_name: worktree.root_name,
2005            updated_entries: worktree.entries,
2006            removed_entries: Default::default(),
2007            scan_id: worktree.scan_id,
2008            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2009            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2010            removed_repositories: Default::default(),
2011        };
2012        for update in proto::split_worktree_update(message) {
2013            session.peer.send(session.connection_id, update.clone())?;
2014        }
2015
2016        // Stream this worktree's diagnostics.
2017        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
2018        if let Some(summary) = worktree_diagnostics.next() {
2019            let message = proto::UpdateDiagnosticSummary {
2020                project_id: project.id.to_proto(),
2021                worktree_id: worktree.id,
2022                summary: Some(summary),
2023                more_summaries: worktree_diagnostics.collect(),
2024            };
2025            session.peer.send(session.connection_id, message)?;
2026        }
2027
2028        for settings_file in worktree.settings_files {
2029            session.peer.send(
2030                session.connection_id,
2031                proto::UpdateWorktreeSettings {
2032                    project_id: project_id.to_proto(),
2033                    worktree_id: worktree.id,
2034                    path: settings_file.path,
2035                    content: Some(settings_file.content),
2036                    kind: Some(settings_file.kind.to_proto() as i32),
2037                },
2038            )?;
2039        }
2040    }
2041
2042    for repository in mem::take(&mut project.repositories) {
2043        for update in split_repository_update(repository) {
2044            session.peer.send(session.connection_id, update)?;
2045        }
2046    }
2047
2048    for language_server in &project.language_servers {
2049        session.peer.send(
2050            session.connection_id,
2051            proto::UpdateLanguageServer {
2052                project_id: project_id.to_proto(),
2053                server_name: Some(language_server.server.name.clone()),
2054                language_server_id: language_server.server.id,
2055                variant: Some(
2056                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2057                        proto::LspDiskBasedDiagnosticsUpdated {},
2058                    ),
2059                ),
2060            },
2061        )?;
2062    }
2063
2064    Ok(())
2065}
2066
2067/// Leave someone elses shared project.
2068async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2069    let sender_id = session.connection_id;
2070    let project_id = ProjectId::from_proto(request.project_id);
2071    let db = session.db().await;
2072
2073    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2074    tracing::info!(
2075        %project_id,
2076        "leave project"
2077    );
2078
2079    project_left(project, &session);
2080    if let Some(room) = room {
2081        room_updated(room, &session.peer);
2082    }
2083
2084    Ok(())
2085}
2086
2087/// Updates other participants with changes to the project
2088async fn update_project(
2089    request: proto::UpdateProject,
2090    response: Response<proto::UpdateProject>,
2091    session: MessageContext,
2092) -> Result<()> {
2093    let project_id = ProjectId::from_proto(request.project_id);
2094    let (room, guest_connection_ids) = &*session
2095        .db()
2096        .await
2097        .update_project(project_id, session.connection_id, &request.worktrees)
2098        .await?;
2099    broadcast(
2100        Some(session.connection_id),
2101        guest_connection_ids.iter().copied(),
2102        |connection_id| {
2103            session
2104                .peer
2105                .forward_send(session.connection_id, connection_id, request.clone())
2106        },
2107    );
2108    if let Some(room) = room {
2109        room_updated(room, &session.peer);
2110    }
2111    response.send(proto::Ack {})?;
2112
2113    Ok(())
2114}
2115
2116/// Updates other participants with changes to the worktree
2117async fn update_worktree(
2118    request: proto::UpdateWorktree,
2119    response: Response<proto::UpdateWorktree>,
2120    session: MessageContext,
2121) -> Result<()> {
2122    let guest_connection_ids = session
2123        .db()
2124        .await
2125        .update_worktree(&request, session.connection_id)
2126        .await?;
2127
2128    broadcast(
2129        Some(session.connection_id),
2130        guest_connection_ids.iter().copied(),
2131        |connection_id| {
2132            session
2133                .peer
2134                .forward_send(session.connection_id, connection_id, request.clone())
2135        },
2136    );
2137    response.send(proto::Ack {})?;
2138    Ok(())
2139}
2140
2141async fn update_repository(
2142    request: proto::UpdateRepository,
2143    response: Response<proto::UpdateRepository>,
2144    session: MessageContext,
2145) -> Result<()> {
2146    let guest_connection_ids = session
2147        .db()
2148        .await
2149        .update_repository(&request, session.connection_id)
2150        .await?;
2151
2152    broadcast(
2153        Some(session.connection_id),
2154        guest_connection_ids.iter().copied(),
2155        |connection_id| {
2156            session
2157                .peer
2158                .forward_send(session.connection_id, connection_id, request.clone())
2159        },
2160    );
2161    response.send(proto::Ack {})?;
2162    Ok(())
2163}
2164
2165async fn remove_repository(
2166    request: proto::RemoveRepository,
2167    response: Response<proto::RemoveRepository>,
2168    session: MessageContext,
2169) -> Result<()> {
2170    let guest_connection_ids = session
2171        .db()
2172        .await
2173        .remove_repository(&request, session.connection_id)
2174        .await?;
2175
2176    broadcast(
2177        Some(session.connection_id),
2178        guest_connection_ids.iter().copied(),
2179        |connection_id| {
2180            session
2181                .peer
2182                .forward_send(session.connection_id, connection_id, request.clone())
2183        },
2184    );
2185    response.send(proto::Ack {})?;
2186    Ok(())
2187}
2188
2189/// Updates other participants with changes to the diagnostics
2190async fn update_diagnostic_summary(
2191    message: proto::UpdateDiagnosticSummary,
2192    session: MessageContext,
2193) -> Result<()> {
2194    let guest_connection_ids = session
2195        .db()
2196        .await
2197        .update_diagnostic_summary(&message, session.connection_id)
2198        .await?;
2199
2200    broadcast(
2201        Some(session.connection_id),
2202        guest_connection_ids.iter().copied(),
2203        |connection_id| {
2204            session
2205                .peer
2206                .forward_send(session.connection_id, connection_id, message.clone())
2207        },
2208    );
2209
2210    Ok(())
2211}
2212
2213/// Updates other participants with changes to the worktree settings
2214async fn update_worktree_settings(
2215    message: proto::UpdateWorktreeSettings,
2216    session: MessageContext,
2217) -> Result<()> {
2218    let guest_connection_ids = session
2219        .db()
2220        .await
2221        .update_worktree_settings(&message, session.connection_id)
2222        .await?;
2223
2224    broadcast(
2225        Some(session.connection_id),
2226        guest_connection_ids.iter().copied(),
2227        |connection_id| {
2228            session
2229                .peer
2230                .forward_send(session.connection_id, connection_id, message.clone())
2231        },
2232    );
2233
2234    Ok(())
2235}
2236
2237/// Notify other participants that a language server has started.
2238async fn start_language_server(
2239    request: proto::StartLanguageServer,
2240    session: MessageContext,
2241) -> Result<()> {
2242    let guest_connection_ids = session
2243        .db()
2244        .await
2245        .start_language_server(&request, session.connection_id)
2246        .await?;
2247
2248    broadcast(
2249        Some(session.connection_id),
2250        guest_connection_ids.iter().copied(),
2251        |connection_id| {
2252            session
2253                .peer
2254                .forward_send(session.connection_id, connection_id, request.clone())
2255        },
2256    );
2257    Ok(())
2258}
2259
2260/// Notify other participants that a language server has changed.
2261async fn update_language_server(
2262    request: proto::UpdateLanguageServer,
2263    session: MessageContext,
2264) -> Result<()> {
2265    let project_id = ProjectId::from_proto(request.project_id);
2266    let db = session.db().await;
2267
2268    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2269        && let Some(capabilities) = update.capabilities.clone()
2270    {
2271        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2272            .await?;
2273    }
2274
2275    let project_connection_ids = db
2276        .project_connection_ids(project_id, session.connection_id, true)
2277        .await?;
2278    broadcast(
2279        Some(session.connection_id),
2280        project_connection_ids.iter().copied(),
2281        |connection_id| {
2282            session
2283                .peer
2284                .forward_send(session.connection_id, connection_id, request.clone())
2285        },
2286    );
2287    Ok(())
2288}
2289
2290/// forward a project request to the host. These requests should be read only
2291/// as guests are allowed to send them.
2292async fn forward_read_only_project_request<T>(
2293    request: T,
2294    response: Response<T>,
2295    session: MessageContext,
2296) -> Result<()>
2297where
2298    T: EntityMessage + RequestMessage,
2299{
2300    let project_id = ProjectId::from_proto(request.remote_entity_id());
2301    let host_connection_id = session
2302        .db()
2303        .await
2304        .host_for_read_only_project_request(project_id, session.connection_id)
2305        .await?;
2306    let payload = session.forward_request(host_connection_id, request).await?;
2307    response.send(payload)?;
2308    Ok(())
2309}
2310
2311/// forward a project request to the host. These requests are disallowed
2312/// for guests.
2313async fn forward_mutating_project_request<T>(
2314    request: T,
2315    response: Response<T>,
2316    session: MessageContext,
2317) -> Result<()>
2318where
2319    T: EntityMessage + RequestMessage,
2320{
2321    let project_id = ProjectId::from_proto(request.remote_entity_id());
2322
2323    let host_connection_id = session
2324        .db()
2325        .await
2326        .host_for_mutating_project_request(project_id, session.connection_id)
2327        .await?;
2328    let payload = session.forward_request(host_connection_id, request).await?;
2329    response.send(payload)?;
2330    Ok(())
2331}
2332
2333async fn lsp_query(
2334    request: proto::LspQuery,
2335    response: Response<proto::LspQuery>,
2336    session: MessageContext,
2337) -> Result<()> {
2338    let (name, should_write) = request.query_name_and_write_permissions();
2339    tracing::Span::current().record("lsp_query_request", name);
2340    tracing::info!("lsp_query message received");
2341    if should_write {
2342        forward_mutating_project_request(request, response, session).await
2343    } else {
2344        forward_read_only_project_request(request, response, session).await
2345    }
2346}
2347
2348/// Notify other participants that a new buffer has been created
2349async fn create_buffer_for_peer(
2350    request: proto::CreateBufferForPeer,
2351    session: MessageContext,
2352) -> Result<()> {
2353    session
2354        .db()
2355        .await
2356        .check_user_is_project_host(
2357            ProjectId::from_proto(request.project_id),
2358            session.connection_id,
2359        )
2360        .await?;
2361    let peer_id = request.peer_id.context("invalid peer id")?;
2362    session
2363        .peer
2364        .forward_send(session.connection_id, peer_id.into(), request)?;
2365    Ok(())
2366}
2367
2368/// Notify other participants that a new image has been created
2369async fn create_image_for_peer(
2370    request: proto::CreateImageForPeer,
2371    session: MessageContext,
2372) -> Result<()> {
2373    session
2374        .db()
2375        .await
2376        .check_user_is_project_host(
2377            ProjectId::from_proto(request.project_id),
2378            session.connection_id,
2379        )
2380        .await?;
2381    let peer_id = request.peer_id.context("invalid peer id")?;
2382    session
2383        .peer
2384        .forward_send(session.connection_id, peer_id.into(), request)?;
2385    Ok(())
2386}
2387
2388/// Notify other participants that a buffer has been updated. This is
2389/// allowed for guests as long as the update is limited to selections.
2390async fn update_buffer(
2391    request: proto::UpdateBuffer,
2392    response: Response<proto::UpdateBuffer>,
2393    session: MessageContext,
2394) -> Result<()> {
2395    let project_id = ProjectId::from_proto(request.project_id);
2396    let mut capability = Capability::ReadOnly;
2397
2398    for op in request.operations.iter() {
2399        match op.variant {
2400            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2401            Some(_) => capability = Capability::ReadWrite,
2402        }
2403    }
2404
2405    let host = {
2406        let guard = session
2407            .db()
2408            .await
2409            .connections_for_buffer_update(project_id, session.connection_id, capability)
2410            .await?;
2411
2412        let (host, guests) = &*guard;
2413
2414        broadcast(
2415            Some(session.connection_id),
2416            guests.clone(),
2417            |connection_id| {
2418                session
2419                    .peer
2420                    .forward_send(session.connection_id, connection_id, request.clone())
2421            },
2422        );
2423
2424        *host
2425    };
2426
2427    if host != session.connection_id {
2428        session.forward_request(host, request.clone()).await?;
2429    }
2430
2431    response.send(proto::Ack {})?;
2432    Ok(())
2433}
2434
2435async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2436    let project_id = ProjectId::from_proto(message.project_id);
2437
2438    let operation = message.operation.as_ref().context("invalid operation")?;
2439    let capability = match operation.variant.as_ref() {
2440        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2441            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2442                match buffer_op.variant {
2443                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2444                        Capability::ReadOnly
2445                    }
2446                    _ => Capability::ReadWrite,
2447                }
2448            } else {
2449                Capability::ReadWrite
2450            }
2451        }
2452        Some(_) => Capability::ReadWrite,
2453        None => Capability::ReadOnly,
2454    };
2455
2456    let guard = session
2457        .db()
2458        .await
2459        .connections_for_buffer_update(project_id, session.connection_id, capability)
2460        .await?;
2461
2462    let (host, guests) = &*guard;
2463
2464    broadcast(
2465        Some(session.connection_id),
2466        guests.iter().chain([host]).copied(),
2467        |connection_id| {
2468            session
2469                .peer
2470                .forward_send(session.connection_id, connection_id, message.clone())
2471        },
2472    );
2473
2474    Ok(())
2475}
2476
2477/// Notify other participants that a project has been updated.
2478async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2479    request: T,
2480    session: MessageContext,
2481) -> Result<()> {
2482    let project_id = ProjectId::from_proto(request.remote_entity_id());
2483    let project_connection_ids = session
2484        .db()
2485        .await
2486        .project_connection_ids(project_id, session.connection_id, false)
2487        .await?;
2488
2489    broadcast(
2490        Some(session.connection_id),
2491        project_connection_ids.iter().copied(),
2492        |connection_id| {
2493            session
2494                .peer
2495                .forward_send(session.connection_id, connection_id, request.clone())
2496        },
2497    );
2498    Ok(())
2499}
2500
2501/// Start following another user in a call.
2502async fn follow(
2503    request: proto::Follow,
2504    response: Response<proto::Follow>,
2505    session: MessageContext,
2506) -> Result<()> {
2507    let room_id = RoomId::from_proto(request.room_id);
2508    let project_id = request.project_id.map(ProjectId::from_proto);
2509    let leader_id = request.leader_id.context("invalid leader id")?.into();
2510    let follower_id = session.connection_id;
2511
2512    session
2513        .db()
2514        .await
2515        .check_room_participants(room_id, leader_id, session.connection_id)
2516        .await?;
2517
2518    let response_payload = session.forward_request(leader_id, request).await?;
2519    response.send(response_payload)?;
2520
2521    if let Some(project_id) = project_id {
2522        let room = session
2523            .db()
2524            .await
2525            .follow(room_id, project_id, leader_id, follower_id)
2526            .await?;
2527        room_updated(&room, &session.peer);
2528    }
2529
2530    Ok(())
2531}
2532
2533/// Stop following another user in a call.
2534async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2535    let room_id = RoomId::from_proto(request.room_id);
2536    let project_id = request.project_id.map(ProjectId::from_proto);
2537    let leader_id = request.leader_id.context("invalid leader id")?.into();
2538    let follower_id = session.connection_id;
2539
2540    session
2541        .db()
2542        .await
2543        .check_room_participants(room_id, leader_id, session.connection_id)
2544        .await?;
2545
2546    session
2547        .peer
2548        .forward_send(session.connection_id, leader_id, request)?;
2549
2550    if let Some(project_id) = project_id {
2551        let room = session
2552            .db()
2553            .await
2554            .unfollow(room_id, project_id, leader_id, follower_id)
2555            .await?;
2556        room_updated(&room, &session.peer);
2557    }
2558
2559    Ok(())
2560}
2561
2562/// Notify everyone following you of your current location.
2563async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2564    let room_id = RoomId::from_proto(request.room_id);
2565    let database = session.db.lock().await;
2566
2567    let connection_ids = if let Some(project_id) = request.project_id {
2568        let project_id = ProjectId::from_proto(project_id);
2569        database
2570            .project_connection_ids(project_id, session.connection_id, true)
2571            .await?
2572    } else {
2573        database
2574            .room_connection_ids(room_id, session.connection_id)
2575            .await?
2576    };
2577
2578    // For now, don't send view update messages back to that view's current leader.
2579    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2580        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2581        _ => None,
2582    });
2583
2584    for connection_id in connection_ids.iter().cloned() {
2585        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2586            session
2587                .peer
2588                .forward_send(session.connection_id, connection_id, request.clone())?;
2589        }
2590    }
2591    Ok(())
2592}
2593
2594/// Get public data about users.
2595async fn get_users(
2596    request: proto::GetUsers,
2597    response: Response<proto::GetUsers>,
2598    session: MessageContext,
2599) -> Result<()> {
2600    let user_ids = request
2601        .user_ids
2602        .into_iter()
2603        .map(UserId::from_proto)
2604        .collect();
2605    let users = session
2606        .db()
2607        .await
2608        .get_users_by_ids(user_ids)
2609        .await?
2610        .into_iter()
2611        .map(|user| proto::User {
2612            id: user.id.to_proto(),
2613            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2614            github_login: user.github_login,
2615            name: user.name,
2616        })
2617        .collect();
2618    response.send(proto::UsersResponse { users })?;
2619    Ok(())
2620}
2621
2622/// Search for users (to invite) buy Github login
2623async fn fuzzy_search_users(
2624    request: proto::FuzzySearchUsers,
2625    response: Response<proto::FuzzySearchUsers>,
2626    session: MessageContext,
2627) -> Result<()> {
2628    let query = request.query;
2629    let users = match query.len() {
2630        0 => vec![],
2631        1 | 2 => session
2632            .db()
2633            .await
2634            .get_user_by_github_login(&query)
2635            .await?
2636            .into_iter()
2637            .collect(),
2638        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2639    };
2640    let users = users
2641        .into_iter()
2642        .filter(|user| user.id != session.user_id())
2643        .map(|user| proto::User {
2644            id: user.id.to_proto(),
2645            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2646            github_login: user.github_login,
2647            name: user.name,
2648        })
2649        .collect();
2650    response.send(proto::UsersResponse { users })?;
2651    Ok(())
2652}
2653
2654/// Send a contact request to another user.
2655async fn request_contact(
2656    request: proto::RequestContact,
2657    response: Response<proto::RequestContact>,
2658    session: MessageContext,
2659) -> Result<()> {
2660    let requester_id = session.user_id();
2661    let responder_id = UserId::from_proto(request.responder_id);
2662    if requester_id == responder_id {
2663        return Err(anyhow!("cannot add yourself as a contact"))?;
2664    }
2665
2666    let notifications = session
2667        .db()
2668        .await
2669        .send_contact_request(requester_id, responder_id)
2670        .await?;
2671
2672    // Update outgoing contact requests of requester
2673    let mut update = proto::UpdateContacts::default();
2674    update.outgoing_requests.push(responder_id.to_proto());
2675    for connection_id in session
2676        .connection_pool()
2677        .await
2678        .user_connection_ids(requester_id)
2679    {
2680        session.peer.send(connection_id, update.clone())?;
2681    }
2682
2683    // Update incoming contact requests of responder
2684    let mut update = proto::UpdateContacts::default();
2685    update
2686        .incoming_requests
2687        .push(proto::IncomingContactRequest {
2688            requester_id: requester_id.to_proto(),
2689        });
2690    let connection_pool = session.connection_pool().await;
2691    for connection_id in connection_pool.user_connection_ids(responder_id) {
2692        session.peer.send(connection_id, update.clone())?;
2693    }
2694
2695    send_notifications(&connection_pool, &session.peer, notifications);
2696
2697    response.send(proto::Ack {})?;
2698    Ok(())
2699}
2700
2701/// Accept or decline a contact request
2702async fn respond_to_contact_request(
2703    request: proto::RespondToContactRequest,
2704    response: Response<proto::RespondToContactRequest>,
2705    session: MessageContext,
2706) -> Result<()> {
2707    let responder_id = session.user_id();
2708    let requester_id = UserId::from_proto(request.requester_id);
2709    let db = session.db().await;
2710    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2711        db.dismiss_contact_notification(responder_id, requester_id)
2712            .await?;
2713    } else {
2714        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2715
2716        let notifications = db
2717            .respond_to_contact_request(responder_id, requester_id, accept)
2718            .await?;
2719        let requester_busy = db.is_user_busy(requester_id).await?;
2720        let responder_busy = db.is_user_busy(responder_id).await?;
2721
2722        let pool = session.connection_pool().await;
2723        // Update responder with new contact
2724        let mut update = proto::UpdateContacts::default();
2725        if accept {
2726            update
2727                .contacts
2728                .push(contact_for_user(requester_id, requester_busy, &pool));
2729        }
2730        update
2731            .remove_incoming_requests
2732            .push(requester_id.to_proto());
2733        for connection_id in pool.user_connection_ids(responder_id) {
2734            session.peer.send(connection_id, update.clone())?;
2735        }
2736
2737        // Update requester with new contact
2738        let mut update = proto::UpdateContacts::default();
2739        if accept {
2740            update
2741                .contacts
2742                .push(contact_for_user(responder_id, responder_busy, &pool));
2743        }
2744        update
2745            .remove_outgoing_requests
2746            .push(responder_id.to_proto());
2747
2748        for connection_id in pool.user_connection_ids(requester_id) {
2749            session.peer.send(connection_id, update.clone())?;
2750        }
2751
2752        send_notifications(&pool, &session.peer, notifications);
2753    }
2754
2755    response.send(proto::Ack {})?;
2756    Ok(())
2757}
2758
2759/// Remove a contact.
2760async fn remove_contact(
2761    request: proto::RemoveContact,
2762    response: Response<proto::RemoveContact>,
2763    session: MessageContext,
2764) -> Result<()> {
2765    let requester_id = session.user_id();
2766    let responder_id = UserId::from_proto(request.user_id);
2767    let db = session.db().await;
2768    let (contact_accepted, deleted_notification_id) =
2769        db.remove_contact(requester_id, responder_id).await?;
2770
2771    let pool = session.connection_pool().await;
2772    // Update outgoing contact requests of requester
2773    let mut update = proto::UpdateContacts::default();
2774    if contact_accepted {
2775        update.remove_contacts.push(responder_id.to_proto());
2776    } else {
2777        update
2778            .remove_outgoing_requests
2779            .push(responder_id.to_proto());
2780    }
2781    for connection_id in pool.user_connection_ids(requester_id) {
2782        session.peer.send(connection_id, update.clone())?;
2783    }
2784
2785    // Update incoming contact requests of responder
2786    let mut update = proto::UpdateContacts::default();
2787    if contact_accepted {
2788        update.remove_contacts.push(requester_id.to_proto());
2789    } else {
2790        update
2791            .remove_incoming_requests
2792            .push(requester_id.to_proto());
2793    }
2794    for connection_id in pool.user_connection_ids(responder_id) {
2795        session.peer.send(connection_id, update.clone())?;
2796        if let Some(notification_id) = deleted_notification_id {
2797            session.peer.send(
2798                connection_id,
2799                proto::DeleteNotification {
2800                    notification_id: notification_id.to_proto(),
2801                },
2802            )?;
2803        }
2804    }
2805
2806    response.send(proto::Ack {})?;
2807    Ok(())
2808}
2809
2810fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
2811    version.0.minor < 139
2812}
2813
2814async fn subscribe_to_channels(
2815    _: proto::SubscribeToChannels,
2816    session: MessageContext,
2817) -> Result<()> {
2818    subscribe_user_to_channels(session.user_id(), &session).await?;
2819    Ok(())
2820}
2821
2822async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2823    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2824    let mut pool = session.connection_pool().await;
2825    for membership in &channels_for_user.channel_memberships {
2826        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2827    }
2828    session.peer.send(
2829        session.connection_id,
2830        build_update_user_channels(&channels_for_user),
2831    )?;
2832    session.peer.send(
2833        session.connection_id,
2834        build_channels_update(channels_for_user),
2835    )?;
2836    Ok(())
2837}
2838
2839/// Creates a new channel.
2840async fn create_channel(
2841    request: proto::CreateChannel,
2842    response: Response<proto::CreateChannel>,
2843    session: MessageContext,
2844) -> Result<()> {
2845    let db = session.db().await;
2846
2847    let parent_id = request.parent_id.map(ChannelId::from_proto);
2848    let (channel, membership) = db
2849        .create_channel(&request.name, parent_id, session.user_id())
2850        .await?;
2851
2852    let root_id = channel.root_id();
2853    let channel = Channel::from_model(channel);
2854
2855    response.send(proto::CreateChannelResponse {
2856        channel: Some(channel.to_proto()),
2857        parent_id: request.parent_id,
2858    })?;
2859
2860    let mut connection_pool = session.connection_pool().await;
2861    if let Some(membership) = membership {
2862        connection_pool.subscribe_to_channel(
2863            membership.user_id,
2864            membership.channel_id,
2865            membership.role,
2866        );
2867        let update = proto::UpdateUserChannels {
2868            channel_memberships: vec![proto::ChannelMembership {
2869                channel_id: membership.channel_id.to_proto(),
2870                role: membership.role.into(),
2871            }],
2872            ..Default::default()
2873        };
2874        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2875            session.peer.send(connection_id, update.clone())?;
2876        }
2877    }
2878
2879    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2880        if !role.can_see_channel(channel.visibility) {
2881            continue;
2882        }
2883
2884        let update = proto::UpdateChannels {
2885            channels: vec![channel.to_proto()],
2886            ..Default::default()
2887        };
2888        session.peer.send(connection_id, update.clone())?;
2889    }
2890
2891    Ok(())
2892}
2893
2894/// Delete a channel
2895async fn delete_channel(
2896    request: proto::DeleteChannel,
2897    response: Response<proto::DeleteChannel>,
2898    session: MessageContext,
2899) -> Result<()> {
2900    let db = session.db().await;
2901
2902    let channel_id = request.channel_id;
2903    let (root_channel, removed_channels) = db
2904        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2905        .await?;
2906    response.send(proto::Ack {})?;
2907
2908    // Notify members of removed channels
2909    let mut update = proto::UpdateChannels::default();
2910    update
2911        .delete_channels
2912        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2913
2914    let connection_pool = session.connection_pool().await;
2915    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2916        session.peer.send(connection_id, update.clone())?;
2917    }
2918
2919    Ok(())
2920}
2921
2922/// Invite someone to join a channel.
2923async fn invite_channel_member(
2924    request: proto::InviteChannelMember,
2925    response: Response<proto::InviteChannelMember>,
2926    session: MessageContext,
2927) -> Result<()> {
2928    let db = session.db().await;
2929    let channel_id = ChannelId::from_proto(request.channel_id);
2930    let invitee_id = UserId::from_proto(request.user_id);
2931    let InviteMemberResult {
2932        channel,
2933        notifications,
2934    } = db
2935        .invite_channel_member(
2936            channel_id,
2937            invitee_id,
2938            session.user_id(),
2939            request.role().into(),
2940        )
2941        .await?;
2942
2943    let update = proto::UpdateChannels {
2944        channel_invitations: vec![channel.to_proto()],
2945        ..Default::default()
2946    };
2947
2948    let connection_pool = session.connection_pool().await;
2949    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2950        session.peer.send(connection_id, update.clone())?;
2951    }
2952
2953    send_notifications(&connection_pool, &session.peer, notifications);
2954
2955    response.send(proto::Ack {})?;
2956    Ok(())
2957}
2958
2959/// remove someone from a channel
2960async fn remove_channel_member(
2961    request: proto::RemoveChannelMember,
2962    response: Response<proto::RemoveChannelMember>,
2963    session: MessageContext,
2964) -> Result<()> {
2965    let db = session.db().await;
2966    let channel_id = ChannelId::from_proto(request.channel_id);
2967    let member_id = UserId::from_proto(request.user_id);
2968
2969    let RemoveChannelMemberResult {
2970        membership_update,
2971        notification_id,
2972    } = db
2973        .remove_channel_member(channel_id, member_id, session.user_id())
2974        .await?;
2975
2976    let mut connection_pool = session.connection_pool().await;
2977    notify_membership_updated(
2978        &mut connection_pool,
2979        membership_update,
2980        member_id,
2981        &session.peer,
2982    );
2983    for connection_id in connection_pool.user_connection_ids(member_id) {
2984        if let Some(notification_id) = notification_id {
2985            session
2986                .peer
2987                .send(
2988                    connection_id,
2989                    proto::DeleteNotification {
2990                        notification_id: notification_id.to_proto(),
2991                    },
2992                )
2993                .trace_err();
2994        }
2995    }
2996
2997    response.send(proto::Ack {})?;
2998    Ok(())
2999}
3000
3001/// Toggle the channel between public and private.
3002/// Care is taken to maintain the invariant that public channels only descend from public channels,
3003/// (though members-only channels can appear at any point in the hierarchy).
3004async fn set_channel_visibility(
3005    request: proto::SetChannelVisibility,
3006    response: Response<proto::SetChannelVisibility>,
3007    session: MessageContext,
3008) -> Result<()> {
3009    let db = session.db().await;
3010    let channel_id = ChannelId::from_proto(request.channel_id);
3011    let visibility = request.visibility().into();
3012
3013    let channel_model = db
3014        .set_channel_visibility(channel_id, visibility, session.user_id())
3015        .await?;
3016    let root_id = channel_model.root_id();
3017    let channel = Channel::from_model(channel_model);
3018
3019    let mut connection_pool = session.connection_pool().await;
3020    for (user_id, role) in connection_pool
3021        .channel_user_ids(root_id)
3022        .collect::<Vec<_>>()
3023        .into_iter()
3024    {
3025        let update = if role.can_see_channel(channel.visibility) {
3026            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3027            proto::UpdateChannels {
3028                channels: vec![channel.to_proto()],
3029                ..Default::default()
3030            }
3031        } else {
3032            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3033            proto::UpdateChannels {
3034                delete_channels: vec![channel.id.to_proto()],
3035                ..Default::default()
3036            }
3037        };
3038
3039        for connection_id in connection_pool.user_connection_ids(user_id) {
3040            session.peer.send(connection_id, update.clone())?;
3041        }
3042    }
3043
3044    response.send(proto::Ack {})?;
3045    Ok(())
3046}
3047
3048/// Alter the role for a user in the channel.
3049async fn set_channel_member_role(
3050    request: proto::SetChannelMemberRole,
3051    response: Response<proto::SetChannelMemberRole>,
3052    session: MessageContext,
3053) -> Result<()> {
3054    let db = session.db().await;
3055    let channel_id = ChannelId::from_proto(request.channel_id);
3056    let member_id = UserId::from_proto(request.user_id);
3057    let result = db
3058        .set_channel_member_role(
3059            channel_id,
3060            session.user_id(),
3061            member_id,
3062            request.role().into(),
3063        )
3064        .await?;
3065
3066    match result {
3067        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3068            let mut connection_pool = session.connection_pool().await;
3069            notify_membership_updated(
3070                &mut connection_pool,
3071                membership_update,
3072                member_id,
3073                &session.peer,
3074            )
3075        }
3076        db::SetMemberRoleResult::InviteUpdated(channel) => {
3077            let update = proto::UpdateChannels {
3078                channel_invitations: vec![channel.to_proto()],
3079                ..Default::default()
3080            };
3081
3082            for connection_id in session
3083                .connection_pool()
3084                .await
3085                .user_connection_ids(member_id)
3086            {
3087                session.peer.send(connection_id, update.clone())?;
3088            }
3089        }
3090    }
3091
3092    response.send(proto::Ack {})?;
3093    Ok(())
3094}
3095
3096/// Change the name of a channel
3097async fn rename_channel(
3098    request: proto::RenameChannel,
3099    response: Response<proto::RenameChannel>,
3100    session: MessageContext,
3101) -> Result<()> {
3102    let db = session.db().await;
3103    let channel_id = ChannelId::from_proto(request.channel_id);
3104    let channel_model = db
3105        .rename_channel(channel_id, session.user_id(), &request.name)
3106        .await?;
3107    let root_id = channel_model.root_id();
3108    let channel = Channel::from_model(channel_model);
3109
3110    response.send(proto::RenameChannelResponse {
3111        channel: Some(channel.to_proto()),
3112    })?;
3113
3114    let connection_pool = session.connection_pool().await;
3115    let update = proto::UpdateChannels {
3116        channels: vec![channel.to_proto()],
3117        ..Default::default()
3118    };
3119    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3120        if role.can_see_channel(channel.visibility) {
3121            session.peer.send(connection_id, update.clone())?;
3122        }
3123    }
3124
3125    Ok(())
3126}
3127
3128/// Move a channel to a new parent.
3129async fn move_channel(
3130    request: proto::MoveChannel,
3131    response: Response<proto::MoveChannel>,
3132    session: MessageContext,
3133) -> Result<()> {
3134    let channel_id = ChannelId::from_proto(request.channel_id);
3135    let to = ChannelId::from_proto(request.to);
3136
3137    let (root_id, channels) = session
3138        .db()
3139        .await
3140        .move_channel(channel_id, to, session.user_id())
3141        .await?;
3142
3143    let connection_pool = session.connection_pool().await;
3144    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3145        let channels = channels
3146            .iter()
3147            .filter_map(|channel| {
3148                if role.can_see_channel(channel.visibility) {
3149                    Some(channel.to_proto())
3150                } else {
3151                    None
3152                }
3153            })
3154            .collect::<Vec<_>>();
3155        if channels.is_empty() {
3156            continue;
3157        }
3158
3159        let update = proto::UpdateChannels {
3160            channels,
3161            ..Default::default()
3162        };
3163
3164        session.peer.send(connection_id, update.clone())?;
3165    }
3166
3167    response.send(Ack {})?;
3168    Ok(())
3169}
3170
3171async fn reorder_channel(
3172    request: proto::ReorderChannel,
3173    response: Response<proto::ReorderChannel>,
3174    session: MessageContext,
3175) -> Result<()> {
3176    let channel_id = ChannelId::from_proto(request.channel_id);
3177    let direction = request.direction();
3178
3179    let updated_channels = session
3180        .db()
3181        .await
3182        .reorder_channel(channel_id, direction, session.user_id())
3183        .await?;
3184
3185    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3186        let connection_pool = session.connection_pool().await;
3187        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3188            let channels = updated_channels
3189                .iter()
3190                .filter_map(|channel| {
3191                    if role.can_see_channel(channel.visibility) {
3192                        Some(channel.to_proto())
3193                    } else {
3194                        None
3195                    }
3196                })
3197                .collect::<Vec<_>>();
3198
3199            if channels.is_empty() {
3200                continue;
3201            }
3202
3203            let update = proto::UpdateChannels {
3204                channels,
3205                ..Default::default()
3206            };
3207
3208            session.peer.send(connection_id, update.clone())?;
3209        }
3210    }
3211
3212    response.send(Ack {})?;
3213    Ok(())
3214}
3215
3216/// Get the list of channel members
3217async fn get_channel_members(
3218    request: proto::GetChannelMembers,
3219    response: Response<proto::GetChannelMembers>,
3220    session: MessageContext,
3221) -> Result<()> {
3222    let db = session.db().await;
3223    let channel_id = ChannelId::from_proto(request.channel_id);
3224    let limit = if request.limit == 0 {
3225        u16::MAX as u64
3226    } else {
3227        request.limit
3228    };
3229    let (members, users) = db
3230        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3231        .await?;
3232    response.send(proto::GetChannelMembersResponse { members, users })?;
3233    Ok(())
3234}
3235
3236/// Accept or decline a channel invitation.
3237async fn respond_to_channel_invite(
3238    request: proto::RespondToChannelInvite,
3239    response: Response<proto::RespondToChannelInvite>,
3240    session: MessageContext,
3241) -> Result<()> {
3242    let db = session.db().await;
3243    let channel_id = ChannelId::from_proto(request.channel_id);
3244    let RespondToChannelInvite {
3245        membership_update,
3246        notifications,
3247    } = db
3248        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3249        .await?;
3250
3251    let mut connection_pool = session.connection_pool().await;
3252    if let Some(membership_update) = membership_update {
3253        notify_membership_updated(
3254            &mut connection_pool,
3255            membership_update,
3256            session.user_id(),
3257            &session.peer,
3258        );
3259    } else {
3260        let update = proto::UpdateChannels {
3261            remove_channel_invitations: vec![channel_id.to_proto()],
3262            ..Default::default()
3263        };
3264
3265        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3266            session.peer.send(connection_id, update.clone())?;
3267        }
3268    };
3269
3270    send_notifications(&connection_pool, &session.peer, notifications);
3271
3272    response.send(proto::Ack {})?;
3273
3274    Ok(())
3275}
3276
3277/// Join the channels' room
3278async fn join_channel(
3279    request: proto::JoinChannel,
3280    response: Response<proto::JoinChannel>,
3281    session: MessageContext,
3282) -> Result<()> {
3283    let channel_id = ChannelId::from_proto(request.channel_id);
3284    join_channel_internal(channel_id, Box::new(response), session).await
3285}
3286
3287trait JoinChannelInternalResponse {
3288    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3289}
3290impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3291    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3292        Response::<proto::JoinChannel>::send(self, result)
3293    }
3294}
/// Lets a `JoinRoom` response handle be used by `join_channel_internal`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // The call is spelled with an explicit `Response::<..>::send` path — presumably so
        // that `Response`'s own `send` is selected rather than this trait method (which
        // would recurse). NOTE(review): confirm against `Response`'s definition.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3300
/// Shared implementation for joining the room of `channel_id`.
///
/// In order, this:
/// 1. leaves the room of any stale connection the database reports for this user,
/// 2. joins the channel in the database,
/// 3. builds LiveKit connection info when a LiveKit client is configured,
/// 4. sends the `JoinRoomResponse` through `response`,
/// 5. notifies about membership, room, channel-participant and contact changes.
///
/// Returns an error if any database call or the response send fails, or if the joined
/// room has no channel attached ("channel not returned").
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    // The inner block scopes `db` (and the connection pool handle taken further down) so
    // they are released before the calls that follow the block.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` obtains its own handle via `session.db()`, so the
            // one held here is released first and re-acquired afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when creating the token fails
        // (the failure is handed to `trace_err` and the join still proceeds).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests receive a guest token and are marked as unable to publish;
                    // every other role receives a regular room token.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Answer the requester before broadcasting anything to other connections.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        // Only present when joining changed the user's channel memberships.
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Broadcast the room's participant list to connections subscribed to the channel.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3394
3395/// Start editing the channel notes
3396async fn join_channel_buffer(
3397    request: proto::JoinChannelBuffer,
3398    response: Response<proto::JoinChannelBuffer>,
3399    session: MessageContext,
3400) -> Result<()> {
3401    let db = session.db().await;
3402    let channel_id = ChannelId::from_proto(request.channel_id);
3403
3404    let open_response = db
3405        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3406        .await?;
3407
3408    let collaborators = open_response.collaborators.clone();
3409    response.send(open_response)?;
3410
3411    let update = UpdateChannelBufferCollaborators {
3412        channel_id: channel_id.to_proto(),
3413        collaborators: collaborators.clone(),
3414    };
3415    channel_buffer_updated(
3416        session.connection_id,
3417        collaborators
3418            .iter()
3419            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3420        &update,
3421        &session.peer,
3422    );
3423
3424    Ok(())
3425}
3426
3427/// Edit the channel notes
3428async fn update_channel_buffer(
3429    request: proto::UpdateChannelBuffer,
3430    session: MessageContext,
3431) -> Result<()> {
3432    let db = session.db().await;
3433    let channel_id = ChannelId::from_proto(request.channel_id);
3434
3435    let (collaborators, epoch, version) = db
3436        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3437        .await?;
3438
3439    channel_buffer_updated(
3440        session.connection_id,
3441        collaborators.clone(),
3442        &proto::UpdateChannelBuffer {
3443            channel_id: channel_id.to_proto(),
3444            operations: request.operations,
3445        },
3446        &session.peer,
3447    );
3448
3449    let pool = &*session.connection_pool().await;
3450
3451    let non_collaborators =
3452        pool.channel_connection_ids(channel_id)
3453            .filter_map(|(connection_id, _)| {
3454                if collaborators.contains(&connection_id) {
3455                    None
3456                } else {
3457                    Some(connection_id)
3458                }
3459            });
3460
3461    broadcast(None, non_collaborators, |peer_id| {
3462        session.peer.send(
3463            peer_id,
3464            proto::UpdateChannels {
3465                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3466                    channel_id: channel_id.to_proto(),
3467                    epoch: epoch as u64,
3468                    version: version.clone(),
3469                }],
3470                ..Default::default()
3471            },
3472        )
3473    });
3474
3475    Ok(())
3476}
3477
3478/// Rejoin the channel notes after a connection blip
3479async fn rejoin_channel_buffers(
3480    request: proto::RejoinChannelBuffers,
3481    response: Response<proto::RejoinChannelBuffers>,
3482    session: MessageContext,
3483) -> Result<()> {
3484    let db = session.db().await;
3485    let buffers = db
3486        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3487        .await?;
3488
3489    for rejoined_buffer in &buffers {
3490        let collaborators_to_notify = rejoined_buffer
3491            .buffer
3492            .collaborators
3493            .iter()
3494            .filter_map(|c| Some(c.peer_id?.into()));
3495        channel_buffer_updated(
3496            session.connection_id,
3497            collaborators_to_notify,
3498            &proto::UpdateChannelBufferCollaborators {
3499                channel_id: rejoined_buffer.buffer.channel_id,
3500                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3501            },
3502            &session.peer,
3503        );
3504    }
3505
3506    response.send(proto::RejoinChannelBuffersResponse {
3507        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3508    })?;
3509
3510    Ok(())
3511}
3512
3513/// Stop editing the channel notes
3514async fn leave_channel_buffer(
3515    request: proto::LeaveChannelBuffer,
3516    response: Response<proto::LeaveChannelBuffer>,
3517    session: MessageContext,
3518) -> Result<()> {
3519    let db = session.db().await;
3520    let channel_id = ChannelId::from_proto(request.channel_id);
3521
3522    let left_buffer = db
3523        .leave_channel_buffer(channel_id, session.connection_id)
3524        .await?;
3525
3526    response.send(Ack {})?;
3527
3528    channel_buffer_updated(
3529        session.connection_id,
3530        left_buffer.connections,
3531        &proto::UpdateChannelBufferCollaborators {
3532            channel_id: channel_id.to_proto(),
3533            collaborators: left_buffer.collaborators,
3534        },
3535        &session.peer,
3536    );
3537
3538    Ok(())
3539}
3540
3541fn channel_buffer_updated<T: EnvelopedMessage>(
3542    sender_id: ConnectionId,
3543    collaborators: impl IntoIterator<Item = ConnectionId>,
3544    message: &T,
3545    peer: &Peer,
3546) {
3547    broadcast(Some(sender_id), collaborators, |peer_id| {
3548        peer.send(peer_id, message.clone())
3549    });
3550}
3551
3552fn send_notifications(
3553    connection_pool: &ConnectionPool,
3554    peer: &Peer,
3555    notifications: db::NotificationBatch,
3556) {
3557    for (user_id, notification) in notifications {
3558        for connection_id in connection_pool.user_connection_ids(user_id) {
3559            if let Err(error) = peer.send(
3560                connection_id,
3561                proto::AddNotification {
3562                    notification: Some(notification.clone()),
3563                },
3564            ) {
3565                tracing::error!(
3566                    "failed to send notification to {:?} {}",
3567                    connection_id,
3568                    error
3569                );
3570            }
3571        }
3572    }
3573}
3574
3575/// Send a message to the channel
3576async fn send_channel_message(
3577    _request: proto::SendChannelMessage,
3578    _response: Response<proto::SendChannelMessage>,
3579    _session: MessageContext,
3580) -> Result<()> {
3581    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3582}
3583
3584/// Delete a channel message
3585async fn remove_channel_message(
3586    _request: proto::RemoveChannelMessage,
3587    _response: Response<proto::RemoveChannelMessage>,
3588    _session: MessageContext,
3589) -> Result<()> {
3590    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3591}
3592
3593async fn update_channel_message(
3594    _request: proto::UpdateChannelMessage,
3595    _response: Response<proto::UpdateChannelMessage>,
3596    _session: MessageContext,
3597) -> Result<()> {
3598    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3599}
3600
3601/// Mark a channel message as read
3602async fn acknowledge_channel_message(
3603    _request: proto::AckChannelMessage,
3604    _session: MessageContext,
3605) -> Result<()> {
3606    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3607}
3608
3609/// Mark a buffer version as synced
3610async fn acknowledge_buffer_version(
3611    request: proto::AckBufferOperation,
3612    session: MessageContext,
3613) -> Result<()> {
3614    let buffer_id = BufferId::from_proto(request.buffer_id);
3615    session
3616        .db()
3617        .await
3618        .observe_buffer_version(
3619            buffer_id,
3620            session.user_id(),
3621            request.epoch as i32,
3622            &request.version,
3623        )
3624        .await?;
3625    Ok(())
3626}
3627
3628/// Start receiving chat updates for a channel
3629async fn join_channel_chat(
3630    _request: proto::JoinChannelChat,
3631    _response: Response<proto::JoinChannelChat>,
3632    _session: MessageContext,
3633) -> Result<()> {
3634    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3635}
3636
3637/// Stop receiving chat updates for a channel
3638async fn leave_channel_chat(
3639    _request: proto::LeaveChannelChat,
3640    _session: MessageContext,
3641) -> Result<()> {
3642    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3643}
3644
3645/// Retrieve the chat history for a channel
3646async fn get_channel_messages(
3647    _request: proto::GetChannelMessages,
3648    _response: Response<proto::GetChannelMessages>,
3649    _session: MessageContext,
3650) -> Result<()> {
3651    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3652}
3653
3654/// Retrieve specific chat messages
3655async fn get_channel_messages_by_id(
3656    _request: proto::GetChannelMessagesById,
3657    _response: Response<proto::GetChannelMessagesById>,
3658    _session: MessageContext,
3659) -> Result<()> {
3660    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3661}
3662
3663/// Retrieve the current users notifications
3664async fn get_notifications(
3665    request: proto::GetNotifications,
3666    response: Response<proto::GetNotifications>,
3667    session: MessageContext,
3668) -> Result<()> {
3669    let notifications = session
3670        .db()
3671        .await
3672        .get_notifications(
3673            session.user_id(),
3674            NOTIFICATION_COUNT_PER_PAGE,
3675            request.before_id.map(db::NotificationId::from_proto),
3676        )
3677        .await?;
3678    response.send(proto::GetNotificationsResponse {
3679        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3680        notifications,
3681    })?;
3682    Ok(())
3683}
3684
3685/// Mark notifications as read
3686async fn mark_notification_as_read(
3687    request: proto::MarkNotificationRead,
3688    response: Response<proto::MarkNotificationRead>,
3689    session: MessageContext,
3690) -> Result<()> {
3691    let database = &session.db().await;
3692    let notifications = database
3693        .mark_notification_as_read_by_id(
3694            session.user_id(),
3695            NotificationId::from_proto(request.notification_id),
3696        )
3697        .await?;
3698    send_notifications(
3699        &*session.connection_pool().await,
3700        &session.peer,
3701        notifications,
3702    );
3703    response.send(proto::Ack {})?;
3704    Ok(())
3705}
3706
3707fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3708    let message = match message {
3709        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3710        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3711        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3712        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3713        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3714            code: frame.code.into(),
3715            reason: frame.reason.as_str().to_owned().into(),
3716        })),
3717        // We should never receive a frame while reading the message, according
3718        // to the `tungstenite` maintainers:
3719        //
3720        // > It cannot occur when you read messages from the WebSocket, but it
3721        // > can be used when you want to send the raw frames (e.g. you want to
3722        // > send the frames to the WebSocket without composing the full message first).
3723        // >
3724        // > — https://github.com/snapview/tungstenite-rs/issues/268
3725        TungsteniteMessage::Frame(_) => {
3726            bail!("received an unexpected frame while reading the message")
3727        }
3728    };
3729
3730    Ok(message)
3731}
3732
3733fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3734    match message {
3735        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3736        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3737        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3738        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3739        AxumMessage::Close(frame) => {
3740            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3741                code: frame.code.into(),
3742                reason: frame.reason.as_ref().into(),
3743            }))
3744        }
3745    }
3746}
3747
3748fn notify_membership_updated(
3749    connection_pool: &mut ConnectionPool,
3750    result: MembershipUpdated,
3751    user_id: UserId,
3752    peer: &Peer,
3753) {
3754    for membership in &result.new_channels.channel_memberships {
3755        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3756    }
3757    for channel_id in &result.removed_channels {
3758        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3759    }
3760
3761    let user_channels_update = proto::UpdateUserChannels {
3762        channel_memberships: result
3763            .new_channels
3764            .channel_memberships
3765            .iter()
3766            .map(|cm| proto::ChannelMembership {
3767                channel_id: cm.channel_id.to_proto(),
3768                role: cm.role.into(),
3769            })
3770            .collect(),
3771        ..Default::default()
3772    };
3773
3774    let mut update = build_channels_update(result.new_channels);
3775    update.delete_channels = result
3776        .removed_channels
3777        .into_iter()
3778        .map(|id| id.to_proto())
3779        .collect();
3780    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3781
3782    for connection_id in connection_pool.user_connection_ids(user_id) {
3783        peer.send(connection_id, user_channels_update.clone())
3784            .trace_err();
3785        peer.send(connection_id, update.clone()).trace_err();
3786    }
3787}
3788
3789fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3790    proto::UpdateUserChannels {
3791        channel_memberships: channels
3792            .channel_memberships
3793            .iter()
3794            .map(|m| proto::ChannelMembership {
3795                channel_id: m.channel_id.to_proto(),
3796                role: m.role.into(),
3797            })
3798            .collect(),
3799        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3800    }
3801}
3802
3803fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3804    let mut update = proto::UpdateChannels::default();
3805
3806    for channel in channels.channels {
3807        update.channels.push(channel.to_proto());
3808    }
3809
3810    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3811
3812    for (channel_id, participants) in channels.channel_participants {
3813        update
3814            .channel_participants
3815            .push(proto::ChannelParticipants {
3816                channel_id: channel_id.to_proto(),
3817                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3818            });
3819    }
3820
3821    for channel in channels.invited_channels {
3822        update.channel_invitations.push(channel.to_proto());
3823    }
3824
3825    update
3826}
3827
3828fn build_initial_contacts_update(
3829    contacts: Vec<db::Contact>,
3830    pool: &ConnectionPool,
3831) -> proto::UpdateContacts {
3832    let mut update = proto::UpdateContacts::default();
3833
3834    for contact in contacts {
3835        match contact {
3836            db::Contact::Accepted { user_id, busy } => {
3837                update.contacts.push(contact_for_user(user_id, busy, pool));
3838            }
3839            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3840            db::Contact::Incoming { user_id } => {
3841                update
3842                    .incoming_requests
3843                    .push(proto::IncomingContactRequest {
3844                        requester_id: user_id.to_proto(),
3845                    })
3846            }
3847        }
3848    }
3849
3850    update
3851}
3852
3853fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3854    proto::Contact {
3855        user_id: user_id.to_proto(),
3856        online: pool.is_user_online(user_id),
3857        busy,
3858    }
3859}
3860
3861fn room_updated(room: &proto::Room, peer: &Peer) {
3862    broadcast(
3863        None,
3864        room.participants
3865            .iter()
3866            .filter_map(|participant| Some(participant.peer_id?.into())),
3867        |peer_id| {
3868            peer.send(
3869                peer_id,
3870                proto::RoomUpdated {
3871                    room: Some(room.clone()),
3872                },
3873            )
3874        },
3875    );
3876}
3877
3878fn channel_updated(
3879    channel: &db::channel::Model,
3880    room: &proto::Room,
3881    peer: &Peer,
3882    pool: &ConnectionPool,
3883) {
3884    let participants = room
3885        .participants
3886        .iter()
3887        .map(|p| p.user_id)
3888        .collect::<Vec<_>>();
3889
3890    broadcast(
3891        None,
3892        pool.channel_connection_ids(channel.root_id())
3893            .filter_map(|(channel_id, role)| {
3894                role.can_see_channel(channel.visibility)
3895                    .then_some(channel_id)
3896            }),
3897        |peer_id| {
3898            peer.send(
3899                peer_id,
3900                proto::UpdateChannels {
3901                    channel_participants: vec![proto::ChannelParticipants {
3902                        channel_id: channel.id.to_proto(),
3903                        participant_user_ids: participants.clone(),
3904                    }],
3905                    ..Default::default()
3906                },
3907            )
3908        },
3909    );
3910}
3911
3912async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3913    let db = session.db().await;
3914
3915    let contacts = db.get_contacts(user_id).await?;
3916    let busy = db.is_user_busy(user_id).await?;
3917
3918    let pool = session.connection_pool().await;
3919    let updated_contact = contact_for_user(user_id, busy, &pool);
3920    for contact in contacts {
3921        if let db::Contact::Accepted {
3922            user_id: contact_user_id,
3923            ..
3924        } = contact
3925        {
3926            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3927                session
3928                    .peer
3929                    .send(
3930                        contact_conn_id,
3931                        proto::UpdateContacts {
3932                            contacts: vec![updated_contact.clone()],
3933                            remove_contacts: Default::default(),
3934                            incoming_requests: Default::default(),
3935                            remove_incoming_requests: Default::default(),
3936                            outgoing_requests: Default::default(),
3937                            remove_outgoing_requests: Default::default(),
3938                        },
3939                    )
3940                    .trace_err();
3941            }
3942        }
3943    }
3944    Ok(())
3945}
3946
/// Removes `connection_id` from the room it is in and notifies everyone affected.
///
/// Does nothing (returns `Ok(())`) when the database reports no room for the connection.
/// Otherwise it notifies collaborators of the left projects, broadcasts the updated room
/// and channel participants, cancels outstanding calls, refreshes contacts, and — when a
/// LiveKit client is configured — removes the participant and deletes the LiveKit room if
/// the room was deleted.
///
/// Only database errors (from `leave_room` and `update_user_contacts`) are returned;
/// send and LiveKit failures are handed to `trace_err`.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front and assigned inside the `if let` below, so the values outlive
    // `left_room`, which goes out of scope at the end of that block.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        // NOTE(review): `project_left` reports `session.connection_id` as the departing
        // peer, not the `connection_id` argument — confirm this is intended when a stale
        // connection is being cleaned up.
        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` itself is moved out below.
        // NOTE(review): as a result, the `room` that is broadcast carries a defaulted
        // `livekit_room` field — confirm this is intended.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    // Rooms that belong to a channel also update that channel's participant list.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        // Scoped so the pool handle is released before `update_user_contacts`, which
        // acquires its own.
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // Covers the session's own user plus every user whose call was canceled.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4020
4021async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4022    let left_channel_buffers = session
4023        .db()
4024        .await
4025        .leave_channel_buffers(session.connection_id)
4026        .await?;
4027
4028    for left_buffer in left_channel_buffers {
4029        channel_buffer_updated(
4030            session.connection_id,
4031            left_buffer.connections,
4032            &proto::UpdateChannelBufferCollaborators {
4033                channel_id: left_buffer.channel_id.to_proto(),
4034                collaborators: left_buffer.collaborators,
4035            },
4036            &session.peer,
4037        );
4038    }
4039
4040    Ok(())
4041}
4042
4043fn project_left(project: &db::LeftProject, session: &Session) {
4044    for connection_id in &project.connection_ids {
4045        if project.should_unshare {
4046            session
4047                .peer
4048                .send(
4049                    *connection_id,
4050                    proto::UnshareProject {
4051                        project_id: project.id.to_proto(),
4052                    },
4053                )
4054                .trace_err();
4055        } else {
4056            session
4057                .peer
4058                .send(
4059                    *connection_id,
4060                    proto::RemoveProjectCollaborator {
4061                        project_id: project.id.to_proto(),
4062                        peer_id: Some(session.connection_id.into()),
4063                    },
4064                )
4065                .trace_err();
4066        }
4067    }
4068}
4069
/// Extension trait for turning a fallible value into an `Option` while reporting the
/// error instead of returning it.
pub trait ResultExt {
    /// The success type yielded when there was no error.
    type Ok;

    /// Returns the success value, or `None` after reporting the error.
    /// (The implementation in this file reports it through `tracing::error!`.)
    fn trace_err(self) -> Option<Self::Ok>;
}
4075
4076impl<T, E> ResultExt for Result<T, E>
4077where
4078    E: std::fmt::Debug,
4079{
4080    type Ok = T;
4081
4082    #[track_caller]
4083    fn trace_err(self) -> Option<T> {
4084        match self {
4085            Ok(value) => Some(value),
4086            Err(error) => {
4087                tracing::error!("{:?}", error);
4088                None
4089            }
4090        }
4091    }
4092}