rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::{
   5    AppState, Error, Result, auth,
   6    db::{
   7        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
   8        InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
   9        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
  10        UserId,
  11    },
  12    executor::Executor,
  13};
  14use anyhow::{Context as _, anyhow, bail};
  15use async_tungstenite::tungstenite::{
  16    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  17};
  18use axum::headers::UserAgent;
  19use axum::{
  20    Extension, Router, TypedHeader,
  21    body::Body,
  22    extract::{
  23        ConnectInfo, WebSocketUpgrade,
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use futures::TryFutureExt as _;
  36use rpc::proto::split_repository_update;
  37use tracing::Span;
  38use util::paths::PathStyle;
  39
  40use futures::{
  41    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  42    stream::FuturesUnordered,
  43};
  44use prometheus::{IntGauge, register_int_gauge};
  45use rpc::{
  46    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  47    proto::{
  48        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  49        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  50    },
  51};
  52use semver::Version;
  53use serde::{Serialize, Serializer};
  54use std::{
  55    any::TypeId,
  56    future::Future,
  57    marker::PhantomData,
  58    mem,
  59    net::SocketAddr,
  60    ops::{Deref, DerefMut},
  61    rc::Rc,
  62    sync::{
  63        Arc, OnceLock,
  64        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  65    },
  66    time::{Duration, Instant},
  67};
  68use tokio::sync::{Semaphore, watch};
  69use tower::ServiceBuilder;
  70use tracing::{
  71    Instrument,
  72    field::{self},
  73    info_span, instrument,
  74};
  75
/// A 30-second timeout related to client reconnection.
///
/// NOTE(review): not referenced in the visible part of this file — confirm the exact
/// semantics at its use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay that `Server::start` waits before cleaning up resources of stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
/// clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size used for notifications.
///
/// NOTE(review): not referenced in the visible part of this file — confirm at its use sites.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Maximum number of `ConnectionGuard`s that may be alive at the same time.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Number of currently held connection slots; incremented by `ConnectionGuard::try_acquire`
/// and decremented when a `ConnectionGuard` is dropped.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Names of the tracing span fields used to record per-message timings.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";
  90
/// Type-erased handler stored in `Server::handlers`.
///
/// It receives the incoming envelope (still type-erased), the `Session` of the sending
/// connection and the tracing `Span` of the message, and returns a boxed future that
/// performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
  93
  94pub struct ConnectionGuard;
  95
  96impl ConnectionGuard {
  97    pub fn try_acquire() -> Result<Self, ()> {
  98        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
  99        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 100            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 101            tracing::error!(
 102                "too many concurrent connections: {}",
 103                current_connections + 1
 104            );
 105            return Err(());
 106        }
 107        Ok(ConnectionGuard)
 108    }
 109}
 110
 111impl Drop for ConnectionGuard {
 112    fn drop(&mut self) {
 113        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 114    }
 115}
 116
 117struct Response<R> {
 118    peer: Arc<Peer>,
 119    receipt: Receipt<R>,
 120    responded: Arc<AtomicBool>,
 121}
 122
 123impl<R: RequestMessage> Response<R> {
 124    fn send(self, payload: R::Response) -> Result<()> {
 125        self.responded.store(true, SeqCst);
 126        self.peer.respond(self.receipt, payload)?;
 127        Ok(())
 128    }
 129}
 130
 131#[derive(Clone, Debug)]
 132pub enum Principal {
 133    User(User),
 134    Impersonated { user: User, admin: User },
 135}
 136
 137impl Principal {
 138    fn update_span(&self, span: &tracing::Span) {
 139        match &self {
 140            Principal::User(user) => {
 141                span.record("user_id", user.id.0);
 142                span.record("login", &user.github_login);
 143            }
 144            Principal::Impersonated { user, admin } => {
 145                span.record("user_id", user.id.0);
 146                span.record("login", &user.github_login);
 147                span.record("impersonator", &admin.github_login);
 148            }
 149        }
 150    }
 151}
 152
 153#[derive(Clone)]
 154struct MessageContext {
 155    session: Session,
 156    span: tracing::Span,
 157}
 158
 159impl Deref for MessageContext {
 160    type Target = Session;
 161
 162    fn deref(&self) -> &Self::Target {
 163        &self.session
 164    }
 165}
 166
 167impl MessageContext {
 168    pub fn forward_request<T: RequestMessage>(
 169        &self,
 170        receiver_id: ConnectionId,
 171        request: T,
 172    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 173        let request_start_time = Instant::now();
 174        let span = self.span.clone();
 175        tracing::info!("start forwarding request");
 176        self.peer
 177            .forward_request(self.connection_id, receiver_id, request)
 178            .inspect(move |_| {
 179                span.record(
 180                    HOST_WAITING_MS,
 181                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 182                );
 183            })
 184            .inspect_err(|_| tracing::error!("error forwarding request"))
 185            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 186    }
 187}
 188
 189#[derive(Clone)]
 190struct Session {
 191    principal: Principal,
 192    connection_id: ConnectionId,
 193    db: Arc<tokio::sync::Mutex<DbHandle>>,
 194    peer: Arc<Peer>,
 195    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 196    app_state: Arc<AppState>,
 197    /// The GeoIP country code for the user.
 198    #[allow(unused)]
 199    geoip_country_code: Option<String>,
 200    #[allow(unused)]
 201    system_id: Option<String>,
 202    _executor: Executor,
 203}
 204
 205impl Session {
 206    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 207        #[cfg(feature = "test-support")]
 208        tokio::task::yield_now().await;
 209        let guard = self.db.lock().await;
 210        #[cfg(feature = "test-support")]
 211        tokio::task::yield_now().await;
 212        guard
 213    }
 214
 215    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 216        #[cfg(feature = "test-support")]
 217        tokio::task::yield_now().await;
 218        let guard = self.connection_pool.lock();
 219        ConnectionPoolGuard {
 220            guard,
 221            _not_send: PhantomData,
 222        }
 223    }
 224
 225    #[expect(dead_code)]
 226    fn is_staff(&self) -> bool {
 227        match &self.principal {
 228            Principal::User(user) => user.admin,
 229            Principal::Impersonated { .. } => true,
 230        }
 231    }
 232
 233    fn user_id(&self) -> UserId {
 234        match &self.principal {
 235            Principal::User(user) => user.id,
 236            Principal::Impersonated { user, .. } => user.id,
 237        }
 238    }
 239}
 240
 241impl Debug for Session {
 242    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 243        let mut result = f.debug_struct("Session");
 244        match &self.principal {
 245            Principal::User(user) => {
 246                result.field("user", &user.github_login);
 247            }
 248            Principal::Impersonated { user, admin } => {
 249                result.field("user", &user.github_login);
 250                result.field("impersonator", &admin.github_login);
 251            }
 252        }
 253        result.field("connection_id", &self.connection_id).finish()
 254    }
 255}
 256
 257struct DbHandle(Arc<Database>);
 258
 259impl Deref for DbHandle {
 260    type Target = Database;
 261
 262    fn deref(&self) -> &Self::Target {
 263        self.0.as_ref()
 264    }
 265}
 266
/// The RPC server: owns the peer, the connection pool and the table of message handlers.
pub struct Server {
    /// Current server id; kept behind a mutex because `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag: `teardown` sends `true`, `reset` sends `false` again.
    teardown: watch::Sender<bool>,
}
 275
/// Lock guard over the shared `ConnectionPool`.
struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc<()>` is `!Send`, so this marker makes the whole guard `!Send`
    // (presumably so it cannot be held across an `.await` in a `Send` future — confirm).
    _not_send: PhantomData<Rc<()>>,
}
 280
/// Serializable view of the server: its peer plus the locked connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through `serialize_deref`, i.e. as the value the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 287
 288pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 289where
 290    S: Serializer,
 291    T: Deref<Target = U>,
 292    U: Serialize,
 293{
 294    Serialize::serialize(value.deref(), serializer)
 295}
 296
 297impl Server {
    /// Creates a server with the given `id`, a new `Peer`, a default connection pool and a
    /// teardown flag initialised to `false`, then registers the handler for every supported
    /// RPC message type and returns the server wrapped in an `Arc`.
    ///
    /// Handlers are stored per message type (see `add_handler`), so each message type is
    /// expected to appear only once in the chain below.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` silently truncates if the id does not fit — confirm the
            // id range makes this safe.
            peer: Peer::new(id.0 as u32),
            app_state,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            // Keep-alive, rooms and calls.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and project state updates.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests forwarded through the read-only / mutating forwarders.
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::DownloadFileByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(lsp_query)
            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffers and messages broadcast from the project host.
            .add_message_handler(create_buffer_for_peer)
            .add_message_handler(create_image_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, channel chat and notifications.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Following and acknowledgements.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            // Contexts, git operations and debugging.
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            // NOTE(review): judging by their names, `GitReset` and `GitCheckoutFiles` modify
            // the host's repository, yet they go through the read-only forwarder — confirm
            // this is intended.
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
            // Agent threads and project search.
            .add_request_handler(share_agent_thread)
            .add_request_handler(get_shared_agent_thread)
            .add_request_handler(forward_project_search_chunk);

        Arc::new(server)
    }
 477
    /// Schedules the cleanup of resources left behind by stale servers and returns at once.
    ///
    /// A detached task is spawned on the app executor. After sleeping for
    /// [`CLEANUP_TIMEOUT`] it:
    /// 1. deletes stale channel chat participants,
    /// 2. clears stale channel-buffer collaborators and sends the refreshed collaborator
    ///    list to the connections reported for each channel buffer,
    /// 3. clears stale room participants; for every refreshed room it broadcasts the room
    ///    (and channel) update, sends `CallCanceled` to users whose calls were canceled,
    ///    pushes contact updates for the affected users, and deletes the LiveKit room when
    ///    no participants remain,
    /// 4. deletes stale channel chat participants once more, clears old worktree entries
    ///    and deletes stale server records.
    ///
    /// Every step is best-effort: each result goes through `trace_err` (which yields an
    /// `Option`; presumably it also logs the error) and the task simply moves on.
    /// This function itself always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is created here, before the task is spawned.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Stale channel buffers: drop stale collaborators and tell the
                    // returned connections about the refreshed collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Stale rooms.
                    for room_id in room_ids {
                        // These stay at their empty defaults when clearing the room fails,
                        // which turns the follow-up steps below into no-ops.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // The pool lock is confined to this block so that it is released
                        // before the next `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Both database lookups are awaited before the pool is locked.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts are notified.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room only when the refreshed room ended up empty.
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 648
    /// Tears the server down: tears down the peer, resets the connection pool and sets the
    /// teardown flag to `true`.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        // The send result is deliberately ignored.
        let _ = self.teardown.send(true);
    }
 654
    /// Test-only: tears the server down and brings it back under the new `id`.
    ///
    /// After `teardown`, the stored id is replaced, the peer is reset with the new id and
    /// the teardown flag is set back to `false`.
    #[cfg(feature = "test-support")]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        // The send result is deliberately ignored.
        let _ = self.teardown.send(false);
    }
 662
    /// Test-only: returns a copy of the current server id.
    #[cfg(feature = "test-support")]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
 667
 668    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 669    where
 670        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
 671        Fut: 'static + Send + Future<Output = Result<()>>,
 672        M: EnvelopedMessage,
 673    {
 674        let prev_handler = self.handlers.insert(
 675            TypeId::of::<M>(),
 676            Box::new(move |envelope, session, span| {
 677                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 678                let received_at = envelope.received_at;
 679                tracing::info!("message received");
 680                let start_time = Instant::now();
 681                let future = (handler)(
 682                    *envelope,
 683                    MessageContext {
 684                        session,
 685                        span: span.clone(),
 686                    },
 687                );
 688                async move {
 689                    let result = future.await;
 690                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 691                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 692                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 693                    span.record(TOTAL_DURATION_MS, total_duration_ms);
 694                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
 695                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
 696                    match result {
 697                        Err(error) => {
 698                            tracing::error!(?error, "error handling message")
 699                        }
 700                        Ok(()) => tracing::info!("finished handling message"),
 701                    }
 702                }
 703                .boxed()
 704            }),
 705        );
 706        if prev_handler.is_some() {
 707            panic!("registered a handler for the same message twice");
 708        }
 709        self
 710    }
 711
 712    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 713    where
 714        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 715        Fut: 'static + Send + Future<Output = Result<()>>,
 716        M: EnvelopedMessage,
 717    {
 718        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 719        self
 720    }
 721
 722    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 723    where
 724        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 725        Fut: Send + Future<Output = Result<()>>,
 726        M: RequestMessage,
 727    {
 728        let handler = Arc::new(handler);
 729        self.add_handler(move |envelope, session| {
 730            let receipt = envelope.receipt();
 731            let handler = handler.clone();
 732            async move {
 733                let peer = session.peer.clone();
 734                let responded = Arc::new(AtomicBool::default());
 735                let response = Response {
 736                    peer: peer.clone(),
 737                    responded: responded.clone(),
 738                    receipt,
 739                };
 740                match (handler)(envelope.payload, response, session).await {
 741                    Ok(()) => {
 742                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 743                            Ok(())
 744                        } else {
 745                            Err(anyhow!("handler did not send a response"))?
 746                        }
 747                    }
 748                    Err(error) => {
 749                        let proto_err = match &error {
 750                            Error::Internal(err) => err.to_proto(),
 751                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 752                        };
 753                        peer.respond_with_error(receipt, proto_err)?;
 754                        Err(error)
 755                    }
 756                }
 757            }
 758        })
 759    }
 760
 761    pub fn handle_connection(
 762        self: &Arc<Self>,
 763        connection: Connection,
 764        address: String,
 765        principal: Principal,
 766        zed_version: ZedVersion,
 767        release_channel: Option<String>,
 768        user_agent: Option<String>,
 769        geoip_country_code: Option<String>,
 770        system_id: Option<String>,
 771        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 772        executor: Executor,
 773        connection_guard: Option<ConnectionGuard>,
 774    ) -> impl Future<Output = ()> + use<> {
 775        let this = self.clone();
 776        let span = info_span!("handle connection", %address,
 777            connection_id=field::Empty,
 778            user_id=field::Empty,
 779            login=field::Empty,
 780            impersonator=field::Empty,
 781            user_agent=field::Empty,
 782            geoip_country_code=field::Empty,
 783            release_channel=field::Empty,
 784        );
 785        principal.update_span(&span);
 786        if let Some(user_agent) = user_agent {
 787            span.record("user_agent", user_agent);
 788        }
 789        if let Some(release_channel) = release_channel {
 790            span.record("release_channel", release_channel);
 791        }
 792
 793        if let Some(country_code) = geoip_country_code.as_ref() {
 794            span.record("geoip_country_code", country_code);
 795        }
 796
 797        let mut teardown = self.teardown.subscribe();
 798        async move {
 799            if *teardown.borrow() {
 800                tracing::error!("server is tearing down");
 801                return;
 802            }
 803
 804            let (connection_id, handle_io, mut incoming_rx) =
 805                this.peer.add_connection(connection, {
 806                    let executor = executor.clone();
 807                    move |duration| executor.sleep(duration)
 808                });
 809            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 810
 811            tracing::info!("connection opened");
 812
 813            let session = Session {
 814                principal: principal.clone(),
 815                connection_id,
 816                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 817                peer: this.peer.clone(),
 818                connection_pool: this.connection_pool.clone(),
 819                app_state: this.app_state.clone(),
 820                geoip_country_code,
 821                system_id,
 822                _executor: executor.clone(),
 823            };
 824
 825            if let Err(error) = this
 826                .send_initial_client_update(
 827                    connection_id,
 828                    zed_version,
 829                    send_connection_id,
 830                    &session,
 831                )
 832                .await
 833            {
 834                tracing::error!(?error, "failed to send initial client update");
 835                return;
 836            }
 837            drop(connection_guard);
 838
 839            let handle_io = handle_io.fuse();
 840            futures::pin_mut!(handle_io);
 841
 842            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 843            // This prevents deadlocks when e.g., client A performs a request to client B and
 844            // client B performs a request to client A. If both clients stop processing further
 845            // messages until their respective request completes, they won't have a chance to
 846            // respond to the other client's request and cause a deadlock.
 847            //
 848            // This arrangement ensures we will attempt to process earlier messages first, but fall
 849            // back to processing messages arrived later in the spirit of making progress.
 850            const MAX_CONCURRENT_HANDLERS: usize = 256;
 851            let mut foreground_message_handlers = FuturesUnordered::new();
 852            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
 853            let get_concurrent_handlers = {
 854                let concurrent_handlers = concurrent_handlers.clone();
 855                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
 856            };
 857            loop {
 858                let next_message = async {
 859                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 860                    let message = incoming_rx.next().await;
 861                    // Cache the concurrent_handlers here, so that we know what the
 862                    // queue looks like as each handler starts
 863                    (permit, message, get_concurrent_handlers())
 864                }
 865                .fuse();
 866                futures::pin_mut!(next_message);
 867                futures::select_biased! {
 868                    _ = teardown.changed().fuse() => return,
 869                    result = handle_io => {
 870                        if let Err(error) = result {
 871                            tracing::error!(?error, "error handling I/O");
 872                        }
 873                        break;
 874                    }
 875                    _ = foreground_message_handlers.next() => {}
 876                    next_message = next_message => {
 877                        let (permit, message, concurrent_handlers) = next_message;
 878                        if let Some(message) = message {
 879                            let type_name = message.payload_type_name();
 880                            // note: we copy all the fields from the parent span so we can query them in the logs.
 881                            // (https://github.com/tokio-rs/tracing/issues/2670).
 882                            let span = tracing::info_span!("receive message",
 883                                %connection_id,
 884                                %address,
 885                                type_name,
 886                                concurrent_handlers,
 887                                user_id=field::Empty,
 888                                login=field::Empty,
 889                                impersonator=field::Empty,
 890                                lsp_query_request=field::Empty,
 891                                release_channel=field::Empty,
 892                                { TOTAL_DURATION_MS }=field::Empty,
 893                                { PROCESSING_DURATION_MS }=field::Empty,
 894                                { QUEUE_DURATION_MS }=field::Empty,
 895                                { HOST_WAITING_MS }=field::Empty
 896                            );
 897                            principal.update_span(&span);
 898                            let span_enter = span.enter();
 899                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 900                                let is_background = message.is_background();
 901                                let handle_message = (handler)(message, session.clone(), span.clone());
 902                                drop(span_enter);
 903
 904                                let handle_message = async move {
 905                                    handle_message.await;
 906                                    drop(permit);
 907                                }.instrument(span);
 908                                if is_background {
 909                                    executor.spawn_detached(handle_message);
 910                                } else {
 911                                    foreground_message_handlers.push(handle_message);
 912                                }
 913                            } else {
 914                                tracing::error!("no message handler");
 915                            }
 916                        } else {
 917                            tracing::info!("connection closed");
 918                            break;
 919                        }
 920                    }
 921                }
 922            }
 923
 924            drop(foreground_message_handlers);
 925            let concurrent_handlers = get_concurrent_handlers();
 926            tracing::info!(concurrent_handlers, "signing out");
 927            if let Err(error) = connection_lost(session, teardown, executor).await {
 928                tracing::error!(?error, "error signing out");
 929            }
 930        }
 931        .instrument(span)
 932    }
 933
 934    async fn send_initial_client_update(
 935        &self,
 936        connection_id: ConnectionId,
 937        zed_version: ZedVersion,
 938        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 939        session: &Session,
 940    ) -> Result<()> {
 941        self.peer.send(
 942            connection_id,
 943            proto::Hello {
 944                peer_id: Some(connection_id.into()),
 945            },
 946        )?;
 947        tracing::info!("sent hello message");
 948        if let Some(send_connection_id) = send_connection_id.take() {
 949            let _ = send_connection_id.send(connection_id);
 950        }
 951
 952        match &session.principal {
 953            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 954                if !user.connected_once {
 955                    self.peer.send(connection_id, proto::ShowContacts {})?;
 956                    self.app_state
 957                        .db
 958                        .set_user_connected_once(user.id, true)
 959                        .await?;
 960                }
 961
 962                let contacts = self.app_state.db.get_contacts(user.id).await?;
 963
 964                {
 965                    let mut pool = self.connection_pool.lock();
 966                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
 967                    self.peer.send(
 968                        connection_id,
 969                        build_initial_contacts_update(contacts, &pool),
 970                    )?;
 971                }
 972
 973                if should_auto_subscribe_to_channels(&zed_version) {
 974                    subscribe_user_to_channels(user.id, session).await?;
 975                }
 976
 977                if let Some(incoming_call) =
 978                    self.app_state.db.incoming_call_for_user(user.id).await?
 979                {
 980                    self.peer.send(connection_id, incoming_call)?;
 981                }
 982
 983                update_user_contacts(user.id, session).await?;
 984            }
 985        }
 986
 987        Ok(())
 988    }
 989
 990    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
 991        ServerSnapshot {
 992            connection_pool: ConnectionPoolGuard {
 993                guard: self.connection_pool.lock(),
 994                _not_send: PhantomData,
 995            },
 996            peer: &self.peer,
 997        }
 998    }
 999}
1000
1001impl Deref for ConnectionPoolGuard<'_> {
1002    type Target = ConnectionPool;
1003
1004    fn deref(&self) -> &Self::Target {
1005        &self.guard
1006    }
1007}
1008
1009impl DerefMut for ConnectionPoolGuard<'_> {
1010    fn deref_mut(&mut self) -> &mut Self::Target {
1011        &mut self.guard
1012    }
1013}
1014
1015impl Drop for ConnectionPoolGuard<'_> {
1016    fn drop(&mut self) {
1017        #[cfg(feature = "test-support")]
1018        self.check_invariants();
1019    }
1020}
1021
1022fn broadcast<F>(
1023    sender_id: Option<ConnectionId>,
1024    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1025    mut f: F,
1026) where
1027    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1028{
1029    for receiver_id in receiver_ids {
1030        if Some(receiver_id) != sender_id
1031            && let Err(error) = f(receiver_id)
1032        {
1033            tracing::error!("failed to send to {:?} {}", receiver_id, error);
1034        }
1035    }
1036}
1037
1038pub struct ProtocolVersion(u32);
1039
1040impl Header for ProtocolVersion {
1041    fn name() -> &'static HeaderName {
1042        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1043        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1044    }
1045
1046    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1047    where
1048        Self: Sized,
1049        I: Iterator<Item = &'i axum::http::HeaderValue>,
1050    {
1051        let version = values
1052            .next()
1053            .ok_or_else(axum::headers::Error::invalid)?
1054            .to_str()
1055            .map_err(|_| axum::headers::Error::invalid())?
1056            .parse()
1057            .map_err(|_| axum::headers::Error::invalid())?;
1058        Ok(Self(version))
1059    }
1060
1061    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1062        values.extend([self.0.to_string().parse().unwrap()]);
1063    }
1064}
1065
1066pub struct AppVersionHeader(Version);
1067impl Header for AppVersionHeader {
1068    fn name() -> &'static HeaderName {
1069        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1070        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1071    }
1072
1073    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1074    where
1075        Self: Sized,
1076        I: Iterator<Item = &'i axum::http::HeaderValue>,
1077    {
1078        let version = values
1079            .next()
1080            .ok_or_else(axum::headers::Error::invalid)?
1081            .to_str()
1082            .map_err(|_| axum::headers::Error::invalid())?
1083            .parse()
1084            .map_err(|_| axum::headers::Error::invalid())?;
1085        Ok(Self(version))
1086    }
1087
1088    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1089        values.extend([self.0.to_string().parse().unwrap()]);
1090    }
1091}
1092
1093#[derive(Debug)]
1094pub struct ReleaseChannelHeader(String);
1095
1096impl Header for ReleaseChannelHeader {
1097    fn name() -> &'static HeaderName {
1098        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1099        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1100    }
1101
1102    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1103    where
1104        Self: Sized,
1105        I: Iterator<Item = &'i axum::http::HeaderValue>,
1106    {
1107        Ok(Self(
1108            values
1109                .next()
1110                .ok_or_else(axum::headers::Error::invalid)?
1111                .to_str()
1112                .map_err(|_| axum::headers::Error::invalid())?
1113                .to_owned(),
1114        ))
1115    }
1116
1117    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1118        values.extend([self.0.parse().unwrap()]);
1119    }
1120}
1121
1122pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1123    Router::new()
1124        .route("/rpc", get(handle_websocket_request))
1125        .layer(
1126            ServiceBuilder::new()
1127                .layer(Extension(server.app_state.clone()))
1128                .layer(middleware::from_fn(auth::validate_header)),
1129        )
1130        .route("/metrics", get(handle_metrics))
1131        .layer(Extension(server))
1132}
1133
1134pub async fn handle_websocket_request(
1135    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1136    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1137    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1138    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1139    Extension(server): Extension<Arc<Server>>,
1140    Extension(principal): Extension<Principal>,
1141    user_agent: Option<TypedHeader<UserAgent>>,
1142    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1143    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1144    ws: WebSocketUpgrade,
1145) -> axum::response::Response {
1146    if protocol_version != rpc::PROTOCOL_VERSION {
1147        return (
1148            StatusCode::UPGRADE_REQUIRED,
1149            "client must be upgraded".to_string(),
1150        )
1151            .into_response();
1152    }
1153
1154    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1155        return (
1156            StatusCode::UPGRADE_REQUIRED,
1157            "no version header found".to_string(),
1158        )
1159            .into_response();
1160    };
1161
1162    let release_channel = release_channel_header.map(|header| header.0.0);
1163
1164    if !version.can_collaborate() {
1165        return (
1166            StatusCode::UPGRADE_REQUIRED,
1167            "client must be upgraded".to_string(),
1168        )
1169            .into_response();
1170    }
1171
1172    let socket_address = socket_address.to_string();
1173
1174    // Acquire connection guard before WebSocket upgrade
1175    let connection_guard = match ConnectionGuard::try_acquire() {
1176        Ok(guard) => guard,
1177        Err(()) => {
1178            return (
1179                StatusCode::SERVICE_UNAVAILABLE,
1180                "Too many concurrent connections",
1181            )
1182                .into_response();
1183        }
1184    };
1185
1186    ws.on_upgrade(move |socket| {
1187        let socket = socket
1188            .map_ok(to_tungstenite_message)
1189            .err_into()
1190            .with(|message| async move { to_axum_message(message) });
1191        let connection = Connection::new(Box::pin(socket));
1192        async move {
1193            server
1194                .handle_connection(
1195                    connection,
1196                    socket_address,
1197                    principal,
1198                    version,
1199                    release_channel,
1200                    user_agent.map(|header| header.to_string()),
1201                    country_code_header.map(|header| header.to_string()),
1202                    system_id_header.map(|header| header.to_string()),
1203                    None,
1204                    Executor::Production,
1205                    Some(connection_guard),
1206                )
1207                .await;
1208        }
1209    })
1210}
1211
1212pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1213    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1214    let connections_metric = CONNECTIONS_METRIC
1215        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1216
1217    let connections = server
1218        .connection_pool
1219        .lock()
1220        .connections()
1221        .filter(|connection| !connection.admin)
1222        .count();
1223    connections_metric.set(connections as _);
1224
1225    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1226    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1227        register_int_gauge!(
1228            "shared_projects",
1229            "number of open projects with one or more guests"
1230        )
1231        .unwrap()
1232    });
1233
1234    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1235    shared_projects_metric.set(shared_projects as _);
1236
1237    let encoder = prometheus::TextEncoder::new();
1238    let metric_families = prometheus::gather();
1239    let encoded_metrics = encoder
1240        .encode_to_string(&metric_families)
1241        .map_err(|err| anyhow!("{err}"))?;
1242    Ok(encoded_metrics)
1243}
1244
1245#[instrument(err, skip(executor))]
1246async fn connection_lost(
1247    session: Session,
1248    mut teardown: watch::Receiver<bool>,
1249    executor: Executor,
1250) -> Result<()> {
1251    session.peer.disconnect(session.connection_id);
1252    session
1253        .connection_pool()
1254        .await
1255        .remove_connection(session.connection_id)?;
1256
1257    session
1258        .db()
1259        .await
1260        .connection_lost(session.connection_id)
1261        .await
1262        .trace_err();
1263
1264    futures::select_biased! {
1265        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1266
1267            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1268            leave_room_for_session(&session, session.connection_id).await.trace_err();
1269            leave_channel_buffers_for_session(&session)
1270                .await
1271                .trace_err();
1272
1273            if !session
1274                .connection_pool()
1275                .await
1276                .is_user_online(session.user_id())
1277            {
1278                let db = session.db().await;
1279                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1280                    room_updated(&room, &session.peer);
1281                }
1282            }
1283
1284            update_user_contacts(session.user_id(), &session).await?;
1285        },
1286        _ = teardown.changed().fuse() => {}
1287    }
1288
1289    Ok(())
1290}
1291
1292/// Acknowledges a ping from a client, used to keep the connection alive.
1293async fn ping(
1294    _: proto::Ping,
1295    response: Response<proto::Ping>,
1296    _session: MessageContext,
1297) -> Result<()> {
1298    response.send(proto::Ack {})?;
1299    Ok(())
1300}
1301
1302/// Creates a new room for calling (outside of channels)
1303async fn create_room(
1304    _request: proto::CreateRoom,
1305    response: Response<proto::CreateRoom>,
1306    session: MessageContext,
1307) -> Result<()> {
1308    let livekit_room = nanoid::nanoid!(30);
1309
1310    let live_kit_connection_info = util::maybe!(async {
1311        let live_kit = session.app_state.livekit_client.as_ref();
1312        let live_kit = live_kit?;
1313        let user_id = session.user_id().to_string();
1314
1315        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1316
1317        Some(proto::LiveKitConnectionInfo {
1318            server_url: live_kit.url().into(),
1319            token,
1320            can_publish: true,
1321        })
1322    })
1323    .await;
1324
1325    let room = session
1326        .db()
1327        .await
1328        .create_room(session.user_id(), session.connection_id, &livekit_room)
1329        .await?;
1330
1331    response.send(proto::CreateRoomResponse {
1332        room: Some(room.clone()),
1333        live_kit_connection_info,
1334    })?;
1335
1336    update_user_contacts(session.user_id(), &session).await?;
1337    Ok(())
1338}
1339
1340/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1341async fn join_room(
1342    request: proto::JoinRoom,
1343    response: Response<proto::JoinRoom>,
1344    session: MessageContext,
1345) -> Result<()> {
1346    let room_id = RoomId::from_proto(request.id);
1347
1348    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1349
1350    if let Some(channel_id) = channel_id {
1351        return join_channel_internal(channel_id, Box::new(response), session).await;
1352    }
1353
1354    let joined_room = {
1355        let room = session
1356            .db()
1357            .await
1358            .join_room(room_id, session.user_id(), session.connection_id)
1359            .await?;
1360        room_updated(&room.room, &session.peer);
1361        room.into_inner()
1362    };
1363
1364    for connection_id in session
1365        .connection_pool()
1366        .await
1367        .user_connection_ids(session.user_id())
1368    {
1369        session
1370            .peer
1371            .send(
1372                connection_id,
1373                proto::CallCanceled {
1374                    room_id: room_id.to_proto(),
1375                },
1376            )
1377            .trace_err();
1378    }
1379
1380    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1381    {
1382        live_kit
1383            .room_token(
1384                &joined_room.room.livekit_room,
1385                &session.user_id().to_string(),
1386            )
1387            .trace_err()
1388            .map(|token| proto::LiveKitConnectionInfo {
1389                server_url: live_kit.url().into(),
1390                token,
1391                can_publish: true,
1392            })
1393    } else {
1394        None
1395    };
1396
1397    response.send(proto::JoinRoomResponse {
1398        room: Some(joined_room.room),
1399        channel_id: None,
1400        live_kit_connection_info,
1401    })?;
1402
1403    update_user_contacts(session.user_id(), &session).await?;
1404    Ok(())
1405}
1406
1407/// Rejoin room is used to reconnect to a room after connection errors.
1408async fn rejoin_room(
1409    request: proto::RejoinRoom,
1410    response: Response<proto::RejoinRoom>,
1411    session: MessageContext,
1412) -> Result<()> {
1413    let room;
1414    let channel;
1415    {
1416        let mut rejoined_room = session
1417            .db()
1418            .await
1419            .rejoin_room(request, session.user_id(), session.connection_id)
1420            .await?;
1421
1422        response.send(proto::RejoinRoomResponse {
1423            room: Some(rejoined_room.room.clone()),
1424            reshared_projects: rejoined_room
1425                .reshared_projects
1426                .iter()
1427                .map(|project| proto::ResharedProject {
1428                    id: project.id.to_proto(),
1429                    collaborators: project
1430                        .collaborators
1431                        .iter()
1432                        .map(|collaborator| collaborator.to_proto())
1433                        .collect(),
1434                })
1435                .collect(),
1436            rejoined_projects: rejoined_room
1437                .rejoined_projects
1438                .iter()
1439                .map(|rejoined_project| rejoined_project.to_proto())
1440                .collect(),
1441        })?;
1442        room_updated(&rejoined_room.room, &session.peer);
1443
1444        for project in &rejoined_room.reshared_projects {
1445            for collaborator in &project.collaborators {
1446                session
1447                    .peer
1448                    .send(
1449                        collaborator.connection_id,
1450                        proto::UpdateProjectCollaborator {
1451                            project_id: project.id.to_proto(),
1452                            old_peer_id: Some(project.old_connection_id.into()),
1453                            new_peer_id: Some(session.connection_id.into()),
1454                        },
1455                    )
1456                    .trace_err();
1457            }
1458
1459            broadcast(
1460                Some(session.connection_id),
1461                project
1462                    .collaborators
1463                    .iter()
1464                    .map(|collaborator| collaborator.connection_id),
1465                |connection_id| {
1466                    session.peer.forward_send(
1467                        session.connection_id,
1468                        connection_id,
1469                        proto::UpdateProject {
1470                            project_id: project.id.to_proto(),
1471                            worktrees: project.worktrees.clone(),
1472                        },
1473                    )
1474                },
1475            );
1476        }
1477
1478        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1479
1480        let rejoined_room = rejoined_room.into_inner();
1481
1482        room = rejoined_room.room;
1483        channel = rejoined_room.channel;
1484    }
1485
1486    if let Some(channel) = channel {
1487        channel_updated(
1488            &channel,
1489            &room,
1490            &session.peer,
1491            &*session.connection_pool().await,
1492        );
1493    }
1494
1495    update_user_contacts(session.user_id(), &session).await?;
1496    Ok(())
1497}
1498
1499fn notify_rejoined_projects(
1500    rejoined_projects: &mut Vec<RejoinedProject>,
1501    session: &Session,
1502) -> Result<()> {
1503    for project in rejoined_projects.iter() {
1504        for collaborator in &project.collaborators {
1505            session
1506                .peer
1507                .send(
1508                    collaborator.connection_id,
1509                    proto::UpdateProjectCollaborator {
1510                        project_id: project.id.to_proto(),
1511                        old_peer_id: Some(project.old_connection_id.into()),
1512                        new_peer_id: Some(session.connection_id.into()),
1513                    },
1514                )
1515                .trace_err();
1516        }
1517    }
1518
1519    for project in rejoined_projects {
1520        for worktree in mem::take(&mut project.worktrees) {
1521            // Stream this worktree's entries.
1522            let message = proto::UpdateWorktree {
1523                project_id: project.id.to_proto(),
1524                worktree_id: worktree.id,
1525                abs_path: worktree.abs_path.clone(),
1526                root_name: worktree.root_name,
1527                updated_entries: worktree.updated_entries,
1528                removed_entries: worktree.removed_entries,
1529                scan_id: worktree.scan_id,
1530                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1531                updated_repositories: worktree.updated_repositories,
1532                removed_repositories: worktree.removed_repositories,
1533            };
1534            for update in proto::split_worktree_update(message) {
1535                session.peer.send(session.connection_id, update)?;
1536            }
1537
1538            // Stream this worktree's diagnostics.
1539            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1540            if let Some(summary) = worktree_diagnostics.next() {
1541                let message = proto::UpdateDiagnosticSummary {
1542                    project_id: project.id.to_proto(),
1543                    worktree_id: worktree.id,
1544                    summary: Some(summary),
1545                    more_summaries: worktree_diagnostics.collect(),
1546                };
1547                session.peer.send(session.connection_id, message)?;
1548            }
1549
1550            for settings_file in worktree.settings_files {
1551                session.peer.send(
1552                    session.connection_id,
1553                    proto::UpdateWorktreeSettings {
1554                        project_id: project.id.to_proto(),
1555                        worktree_id: worktree.id,
1556                        path: settings_file.path,
1557                        content: Some(settings_file.content),
1558                        kind: Some(settings_file.kind.to_proto().into()),
1559                        outside_worktree: Some(settings_file.outside_worktree),
1560                    },
1561                )?;
1562            }
1563        }
1564
1565        for repository in mem::take(&mut project.updated_repositories) {
1566            for update in split_repository_update(repository) {
1567                session.peer.send(session.connection_id, update)?;
1568            }
1569        }
1570
1571        for id in mem::take(&mut project.removed_repositories) {
1572            session.peer.send(
1573                session.connection_id,
1574                proto::RemoveRepository {
1575                    project_id: project.id.to_proto(),
1576                    id,
1577                },
1578            )?;
1579        }
1580    }
1581
1582    Ok(())
1583}
1584
1585/// leave room disconnects from the room.
1586async fn leave_room(
1587    _: proto::LeaveRoom,
1588    response: Response<proto::LeaveRoom>,
1589    session: MessageContext,
1590) -> Result<()> {
1591    leave_room_for_session(&session, session.connection_id).await?;
1592    response.send(proto::Ack {})?;
1593    Ok(())
1594}
1595
1596/// Updates the permissions of someone else in the room.
1597async fn set_room_participant_role(
1598    request: proto::SetRoomParticipantRole,
1599    response: Response<proto::SetRoomParticipantRole>,
1600    session: MessageContext,
1601) -> Result<()> {
1602    let user_id = UserId::from_proto(request.user_id);
1603    let role = ChannelRole::from(request.role());
1604
1605    let (livekit_room, can_publish) = {
1606        let room = session
1607            .db()
1608            .await
1609            .set_room_participant_role(
1610                session.user_id(),
1611                RoomId::from_proto(request.room_id),
1612                user_id,
1613                role,
1614            )
1615            .await?;
1616
1617        let livekit_room = room.livekit_room.clone();
1618        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1619        room_updated(&room, &session.peer);
1620        (livekit_room, can_publish)
1621    };
1622
1623    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1624        live_kit
1625            .update_participant(
1626                livekit_room.clone(),
1627                request.user_id.to_string(),
1628                livekit_api::proto::ParticipantPermission {
1629                    can_subscribe: true,
1630                    can_publish,
1631                    can_publish_data: can_publish,
1632                    hidden: false,
1633                    recorder: false,
1634                },
1635            )
1636            .await
1637            .trace_err();
1638    }
1639
1640    response.send(proto::Ack {})?;
1641    Ok(())
1642}
1643
1644/// Call someone else into the current room
1645async fn call(
1646    request: proto::Call,
1647    response: Response<proto::Call>,
1648    session: MessageContext,
1649) -> Result<()> {
1650    let room_id = RoomId::from_proto(request.room_id);
1651    let calling_user_id = session.user_id();
1652    let calling_connection_id = session.connection_id;
1653    let called_user_id = UserId::from_proto(request.called_user_id);
1654    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1655    if !session
1656        .db()
1657        .await
1658        .has_contact(calling_user_id, called_user_id)
1659        .await?
1660    {
1661        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1662    }
1663
1664    let incoming_call = {
1665        let (room, incoming_call) = &mut *session
1666            .db()
1667            .await
1668            .call(
1669                room_id,
1670                calling_user_id,
1671                calling_connection_id,
1672                called_user_id,
1673                initial_project_id,
1674            )
1675            .await?;
1676        room_updated(room, &session.peer);
1677        mem::take(incoming_call)
1678    };
1679    update_user_contacts(called_user_id, &session).await?;
1680
1681    let mut calls = session
1682        .connection_pool()
1683        .await
1684        .user_connection_ids(called_user_id)
1685        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1686        .collect::<FuturesUnordered<_>>();
1687
1688    while let Some(call_response) = calls.next().await {
1689        match call_response.as_ref() {
1690            Ok(_) => {
1691                response.send(proto::Ack {})?;
1692                return Ok(());
1693            }
1694            Err(_) => {
1695                call_response.trace_err();
1696            }
1697        }
1698    }
1699
1700    {
1701        let room = session
1702            .db()
1703            .await
1704            .call_failed(room_id, called_user_id)
1705            .await?;
1706        room_updated(&room, &session.peer);
1707    }
1708    update_user_contacts(called_user_id, &session).await?;
1709
1710    Err(anyhow!("failed to ring user"))?
1711}
1712
1713/// Cancel an outgoing call.
1714async fn cancel_call(
1715    request: proto::CancelCall,
1716    response: Response<proto::CancelCall>,
1717    session: MessageContext,
1718) -> Result<()> {
1719    let called_user_id = UserId::from_proto(request.called_user_id);
1720    let room_id = RoomId::from_proto(request.room_id);
1721    {
1722        let room = session
1723            .db()
1724            .await
1725            .cancel_call(room_id, session.connection_id, called_user_id)
1726            .await?;
1727        room_updated(&room, &session.peer);
1728    }
1729
1730    for connection_id in session
1731        .connection_pool()
1732        .await
1733        .user_connection_ids(called_user_id)
1734    {
1735        session
1736            .peer
1737            .send(
1738                connection_id,
1739                proto::CallCanceled {
1740                    room_id: room_id.to_proto(),
1741                },
1742            )
1743            .trace_err();
1744    }
1745    response.send(proto::Ack {})?;
1746
1747    update_user_contacts(called_user_id, &session).await?;
1748    Ok(())
1749}
1750
1751/// Decline an incoming call.
1752async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1753    let room_id = RoomId::from_proto(message.room_id);
1754    {
1755        let room = session
1756            .db()
1757            .await
1758            .decline_call(Some(room_id), session.user_id())
1759            .await?
1760            .context("declining call")?;
1761        room_updated(&room, &session.peer);
1762    }
1763
1764    for connection_id in session
1765        .connection_pool()
1766        .await
1767        .user_connection_ids(session.user_id())
1768    {
1769        session
1770            .peer
1771            .send(
1772                connection_id,
1773                proto::CallCanceled {
1774                    room_id: room_id.to_proto(),
1775                },
1776            )
1777            .trace_err();
1778    }
1779    update_user_contacts(session.user_id(), &session).await?;
1780    Ok(())
1781}
1782
1783/// Updates other participants in the room with your current location.
1784async fn update_participant_location(
1785    request: proto::UpdateParticipantLocation,
1786    response: Response<proto::UpdateParticipantLocation>,
1787    session: MessageContext,
1788) -> Result<()> {
1789    let room_id = RoomId::from_proto(request.room_id);
1790    let location = request.location.context("invalid location")?;
1791
1792    let db = session.db().await;
1793    let room = db
1794        .update_room_participant_location(room_id, session.connection_id, location)
1795        .await?;
1796
1797    room_updated(&room, &session.peer);
1798    response.send(proto::Ack {})?;
1799    Ok(())
1800}
1801
1802/// Share a project into the room.
1803async fn share_project(
1804    request: proto::ShareProject,
1805    response: Response<proto::ShareProject>,
1806    session: MessageContext,
1807) -> Result<()> {
1808    let (project_id, room) = &*session
1809        .db()
1810        .await
1811        .share_project(
1812            RoomId::from_proto(request.room_id),
1813            session.connection_id,
1814            &request.worktrees,
1815            request.is_ssh_project,
1816            request.windows_paths.unwrap_or(false),
1817        )
1818        .await?;
1819    response.send(proto::ShareProjectResponse {
1820        project_id: project_id.to_proto(),
1821    })?;
1822    room_updated(room, &session.peer);
1823
1824    Ok(())
1825}
1826
1827/// Unshare a project from the room.
1828async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1829    let project_id = ProjectId::from_proto(message.project_id);
1830    unshare_project_internal(project_id, session.connection_id, &session).await
1831}
1832
1833async fn unshare_project_internal(
1834    project_id: ProjectId,
1835    connection_id: ConnectionId,
1836    session: &Session,
1837) -> Result<()> {
1838    let delete = {
1839        let room_guard = session
1840            .db()
1841            .await
1842            .unshare_project(project_id, connection_id)
1843            .await?;
1844
1845        let (delete, room, guest_connection_ids) = &*room_guard;
1846
1847        let message = proto::UnshareProject {
1848            project_id: project_id.to_proto(),
1849        };
1850
1851        broadcast(
1852            Some(connection_id),
1853            guest_connection_ids.iter().copied(),
1854            |conn_id| session.peer.send(conn_id, message.clone()),
1855        );
1856        if let Some(room) = room {
1857            room_updated(room, &session.peer);
1858        }
1859
1860        *delete
1861    };
1862
1863    if delete {
1864        let db = session.db().await;
1865        db.delete_project(project_id).await?;
1866    }
1867
1868    Ok(())
1869}
1870
1871/// Join someone elses shared project.
1872async fn join_project(
1873    request: proto::JoinProject,
1874    response: Response<proto::JoinProject>,
1875    session: MessageContext,
1876) -> Result<()> {
1877    let project_id = ProjectId::from_proto(request.project_id);
1878
1879    tracing::info!(%project_id, "join project");
1880
1881    let db = session.db().await;
1882    let (project, replica_id) = &mut *db
1883        .join_project(
1884            project_id,
1885            session.connection_id,
1886            session.user_id(),
1887            request.committer_name.clone(),
1888            request.committer_email.clone(),
1889        )
1890        .await?;
1891    drop(db);
1892    tracing::info!(%project_id, "join remote project");
1893    let collaborators = project
1894        .collaborators
1895        .iter()
1896        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1897        .map(|collaborator| collaborator.to_proto())
1898        .collect::<Vec<_>>();
1899    let project_id = project.id;
1900    let guest_user_id = session.user_id();
1901
1902    let worktrees = project
1903        .worktrees
1904        .iter()
1905        .map(|(id, worktree)| proto::WorktreeMetadata {
1906            id: *id,
1907            root_name: worktree.root_name.clone(),
1908            visible: worktree.visible,
1909            abs_path: worktree.abs_path.clone(),
1910        })
1911        .collect::<Vec<_>>();
1912
1913    let add_project_collaborator = proto::AddProjectCollaborator {
1914        project_id: project_id.to_proto(),
1915        collaborator: Some(proto::Collaborator {
1916            peer_id: Some(session.connection_id.into()),
1917            replica_id: replica_id.0 as u32,
1918            user_id: guest_user_id.to_proto(),
1919            is_host: false,
1920            committer_name: request.committer_name.clone(),
1921            committer_email: request.committer_email.clone(),
1922        }),
1923    };
1924
1925    for collaborator in &collaborators {
1926        session
1927            .peer
1928            .send(
1929                collaborator.peer_id.unwrap().into(),
1930                add_project_collaborator.clone(),
1931            )
1932            .trace_err();
1933    }
1934
1935    // First, we send the metadata associated with each worktree.
1936    let (language_servers, language_server_capabilities) = project
1937        .language_servers
1938        .clone()
1939        .into_iter()
1940        .map(|server| (server.server, server.capabilities))
1941        .unzip();
1942    response.send(proto::JoinProjectResponse {
1943        project_id: project.id.0 as u64,
1944        worktrees,
1945        replica_id: replica_id.0 as u32,
1946        collaborators,
1947        language_servers,
1948        language_server_capabilities,
1949        role: project.role.into(),
1950        windows_paths: project.path_style == PathStyle::Windows,
1951    })?;
1952
1953    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1954        // Stream this worktree's entries.
1955        let message = proto::UpdateWorktree {
1956            project_id: project_id.to_proto(),
1957            worktree_id,
1958            abs_path: worktree.abs_path.clone(),
1959            root_name: worktree.root_name,
1960            updated_entries: worktree.entries,
1961            removed_entries: Default::default(),
1962            scan_id: worktree.scan_id,
1963            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1964            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1965            removed_repositories: Default::default(),
1966        };
1967        for update in proto::split_worktree_update(message) {
1968            session.peer.send(session.connection_id, update.clone())?;
1969        }
1970
1971        // Stream this worktree's diagnostics.
1972        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1973        if let Some(summary) = worktree_diagnostics.next() {
1974            let message = proto::UpdateDiagnosticSummary {
1975                project_id: project.id.to_proto(),
1976                worktree_id: worktree.id,
1977                summary: Some(summary),
1978                more_summaries: worktree_diagnostics.collect(),
1979            };
1980            session.peer.send(session.connection_id, message)?;
1981        }
1982
1983        for settings_file in worktree.settings_files {
1984            session.peer.send(
1985                session.connection_id,
1986                proto::UpdateWorktreeSettings {
1987                    project_id: project_id.to_proto(),
1988                    worktree_id: worktree.id,
1989                    path: settings_file.path,
1990                    content: Some(settings_file.content),
1991                    kind: Some(settings_file.kind.to_proto() as i32),
1992                    outside_worktree: Some(settings_file.outside_worktree),
1993                },
1994            )?;
1995        }
1996    }
1997
1998    for repository in mem::take(&mut project.repositories) {
1999        for update in split_repository_update(repository) {
2000            session.peer.send(session.connection_id, update)?;
2001        }
2002    }
2003
2004    for language_server in &project.language_servers {
2005        session.peer.send(
2006            session.connection_id,
2007            proto::UpdateLanguageServer {
2008                project_id: project_id.to_proto(),
2009                server_name: Some(language_server.server.name.clone()),
2010                language_server_id: language_server.server.id,
2011                variant: Some(
2012                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2013                        proto::LspDiskBasedDiagnosticsUpdated {},
2014                    ),
2015                ),
2016            },
2017        )?;
2018    }
2019
2020    Ok(())
2021}
2022
2023/// Leave someone elses shared project.
2024async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2025    let sender_id = session.connection_id;
2026    let project_id = ProjectId::from_proto(request.project_id);
2027    let db = session.db().await;
2028
2029    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2030    tracing::info!(
2031        %project_id,
2032        "leave project"
2033    );
2034
2035    project_left(project, &session);
2036    if let Some(room) = room {
2037        room_updated(room, &session.peer);
2038    }
2039
2040    Ok(())
2041}
2042
2043/// Updates other participants with changes to the project
2044async fn update_project(
2045    request: proto::UpdateProject,
2046    response: Response<proto::UpdateProject>,
2047    session: MessageContext,
2048) -> Result<()> {
2049    let project_id = ProjectId::from_proto(request.project_id);
2050    let (room, guest_connection_ids) = &*session
2051        .db()
2052        .await
2053        .update_project(project_id, session.connection_id, &request.worktrees)
2054        .await?;
2055    broadcast(
2056        Some(session.connection_id),
2057        guest_connection_ids.iter().copied(),
2058        |connection_id| {
2059            session
2060                .peer
2061                .forward_send(session.connection_id, connection_id, request.clone())
2062        },
2063    );
2064    if let Some(room) = room {
2065        room_updated(room, &session.peer);
2066    }
2067    response.send(proto::Ack {})?;
2068
2069    Ok(())
2070}
2071
2072/// Updates other participants with changes to the worktree
2073async fn update_worktree(
2074    request: proto::UpdateWorktree,
2075    response: Response<proto::UpdateWorktree>,
2076    session: MessageContext,
2077) -> Result<()> {
2078    let guest_connection_ids = session
2079        .db()
2080        .await
2081        .update_worktree(&request, session.connection_id)
2082        .await?;
2083
2084    broadcast(
2085        Some(session.connection_id),
2086        guest_connection_ids.iter().copied(),
2087        |connection_id| {
2088            session
2089                .peer
2090                .forward_send(session.connection_id, connection_id, request.clone())
2091        },
2092    );
2093    response.send(proto::Ack {})?;
2094    Ok(())
2095}
2096
2097async fn update_repository(
2098    request: proto::UpdateRepository,
2099    response: Response<proto::UpdateRepository>,
2100    session: MessageContext,
2101) -> Result<()> {
2102    let guest_connection_ids = session
2103        .db()
2104        .await
2105        .update_repository(&request, session.connection_id)
2106        .await?;
2107
2108    broadcast(
2109        Some(session.connection_id),
2110        guest_connection_ids.iter().copied(),
2111        |connection_id| {
2112            session
2113                .peer
2114                .forward_send(session.connection_id, connection_id, request.clone())
2115        },
2116    );
2117    response.send(proto::Ack {})?;
2118    Ok(())
2119}
2120
2121async fn remove_repository(
2122    request: proto::RemoveRepository,
2123    response: Response<proto::RemoveRepository>,
2124    session: MessageContext,
2125) -> Result<()> {
2126    let guest_connection_ids = session
2127        .db()
2128        .await
2129        .remove_repository(&request, session.connection_id)
2130        .await?;
2131
2132    broadcast(
2133        Some(session.connection_id),
2134        guest_connection_ids.iter().copied(),
2135        |connection_id| {
2136            session
2137                .peer
2138                .forward_send(session.connection_id, connection_id, request.clone())
2139        },
2140    );
2141    response.send(proto::Ack {})?;
2142    Ok(())
2143}
2144
2145/// Updates other participants with changes to the diagnostics
2146async fn update_diagnostic_summary(
2147    message: proto::UpdateDiagnosticSummary,
2148    session: MessageContext,
2149) -> Result<()> {
2150    let guest_connection_ids = session
2151        .db()
2152        .await
2153        .update_diagnostic_summary(&message, session.connection_id)
2154        .await?;
2155
2156    broadcast(
2157        Some(session.connection_id),
2158        guest_connection_ids.iter().copied(),
2159        |connection_id| {
2160            session
2161                .peer
2162                .forward_send(session.connection_id, connection_id, message.clone())
2163        },
2164    );
2165
2166    Ok(())
2167}
2168
2169/// Updates other participants with changes to the worktree settings
2170async fn update_worktree_settings(
2171    message: proto::UpdateWorktreeSettings,
2172    session: MessageContext,
2173) -> Result<()> {
2174    let guest_connection_ids = session
2175        .db()
2176        .await
2177        .update_worktree_settings(&message, session.connection_id)
2178        .await?;
2179
2180    broadcast(
2181        Some(session.connection_id),
2182        guest_connection_ids.iter().copied(),
2183        |connection_id| {
2184            session
2185                .peer
2186                .forward_send(session.connection_id, connection_id, message.clone())
2187        },
2188    );
2189
2190    Ok(())
2191}
2192
2193/// Notify other participants that a language server has started.
2194async fn start_language_server(
2195    request: proto::StartLanguageServer,
2196    session: MessageContext,
2197) -> Result<()> {
2198    let guest_connection_ids = session
2199        .db()
2200        .await
2201        .start_language_server(&request, session.connection_id)
2202        .await?;
2203
2204    broadcast(
2205        Some(session.connection_id),
2206        guest_connection_ids.iter().copied(),
2207        |connection_id| {
2208            session
2209                .peer
2210                .forward_send(session.connection_id, connection_id, request.clone())
2211        },
2212    );
2213    Ok(())
2214}
2215
2216/// Notify other participants that a language server has changed.
2217async fn update_language_server(
2218    request: proto::UpdateLanguageServer,
2219    session: MessageContext,
2220) -> Result<()> {
2221    let project_id = ProjectId::from_proto(request.project_id);
2222    let db = session.db().await;
2223
2224    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2225        && let Some(capabilities) = update.capabilities.clone()
2226    {
2227        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2228            .await?;
2229    }
2230
2231    let project_connection_ids = db
2232        .project_connection_ids(project_id, session.connection_id, true)
2233        .await?;
2234    broadcast(
2235        Some(session.connection_id),
2236        project_connection_ids.iter().copied(),
2237        |connection_id| {
2238            session
2239                .peer
2240                .forward_send(session.connection_id, connection_id, request.clone())
2241        },
2242    );
2243    Ok(())
2244}
2245
2246/// forward a project request to the host. These requests should be read only
2247/// as guests are allowed to send them.
2248async fn forward_read_only_project_request<T>(
2249    request: T,
2250    response: Response<T>,
2251    session: MessageContext,
2252) -> Result<()>
2253where
2254    T: EntityMessage + RequestMessage,
2255{
2256    let project_id = ProjectId::from_proto(request.remote_entity_id());
2257    let host_connection_id = session
2258        .db()
2259        .await
2260        .host_for_read_only_project_request(project_id, session.connection_id)
2261        .await?;
2262    let payload = session.forward_request(host_connection_id, request).await?;
2263    response.send(payload)?;
2264    Ok(())
2265}
2266
2267/// forward a project request to the host. These requests are disallowed
2268/// for guests.
2269async fn forward_mutating_project_request<T>(
2270    request: T,
2271    response: Response<T>,
2272    session: MessageContext,
2273) -> Result<()>
2274where
2275    T: EntityMessage + RequestMessage,
2276{
2277    let project_id = ProjectId::from_proto(request.remote_entity_id());
2278
2279    let host_connection_id = session
2280        .db()
2281        .await
2282        .host_for_mutating_project_request(project_id, session.connection_id)
2283        .await?;
2284    let payload = session.forward_request(host_connection_id, request).await?;
2285    response.send(payload)?;
2286    Ok(())
2287}
2288
2289async fn lsp_query(
2290    request: proto::LspQuery,
2291    response: Response<proto::LspQuery>,
2292    session: MessageContext,
2293) -> Result<()> {
2294    let (name, should_write) = request.query_name_and_write_permissions();
2295    tracing::Span::current().record("lsp_query_request", name);
2296    tracing::info!("lsp_query message received");
2297    if should_write {
2298        forward_mutating_project_request(request, response, session).await
2299    } else {
2300        forward_read_only_project_request(request, response, session).await
2301    }
2302}
2303
2304/// Notify other participants that a new buffer has been created
2305async fn create_buffer_for_peer(
2306    request: proto::CreateBufferForPeer,
2307    session: MessageContext,
2308) -> Result<()> {
2309    session
2310        .db()
2311        .await
2312        .check_user_is_project_host(
2313            ProjectId::from_proto(request.project_id),
2314            session.connection_id,
2315        )
2316        .await?;
2317    let peer_id = request.peer_id.context("invalid peer id")?;
2318    session
2319        .peer
2320        .forward_send(session.connection_id, peer_id.into(), request)?;
2321    Ok(())
2322}
2323
2324/// Notify other participants that a new image has been created
2325async fn create_image_for_peer(
2326    request: proto::CreateImageForPeer,
2327    session: MessageContext,
2328) -> Result<()> {
2329    session
2330        .db()
2331        .await
2332        .check_user_is_project_host(
2333            ProjectId::from_proto(request.project_id),
2334            session.connection_id,
2335        )
2336        .await?;
2337    let peer_id = request.peer_id.context("invalid peer id")?;
2338    session
2339        .peer
2340        .forward_send(session.connection_id, peer_id.into(), request)?;
2341    Ok(())
2342}
2343
2344/// Notify other participants that a buffer has been updated. This is
2345/// allowed for guests as long as the update is limited to selections.
2346async fn update_buffer(
2347    request: proto::UpdateBuffer,
2348    response: Response<proto::UpdateBuffer>,
2349    session: MessageContext,
2350) -> Result<()> {
2351    let project_id = ProjectId::from_proto(request.project_id);
2352    let mut capability = Capability::ReadOnly;
2353
2354    for op in request.operations.iter() {
2355        match op.variant {
2356            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2357            Some(_) => capability = Capability::ReadWrite,
2358        }
2359    }
2360
2361    let host = {
2362        let guard = session
2363            .db()
2364            .await
2365            .connections_for_buffer_update(project_id, session.connection_id, capability)
2366            .await?;
2367
2368        let (host, guests) = &*guard;
2369
2370        broadcast(
2371            Some(session.connection_id),
2372            guests.clone(),
2373            |connection_id| {
2374                session
2375                    .peer
2376                    .forward_send(session.connection_id, connection_id, request.clone())
2377            },
2378        );
2379
2380        *host
2381    };
2382
2383    if host != session.connection_id {
2384        session.forward_request(host, request.clone()).await?;
2385    }
2386
2387    response.send(proto::Ack {})?;
2388    Ok(())
2389}
2390
2391async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2392    let project_id = ProjectId::from_proto(message.project_id);
2393
2394    let operation = message.operation.as_ref().context("invalid operation")?;
2395    let capability = match operation.variant.as_ref() {
2396        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2397            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2398                match buffer_op.variant {
2399                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2400                        Capability::ReadOnly
2401                    }
2402                    _ => Capability::ReadWrite,
2403                }
2404            } else {
2405                Capability::ReadWrite
2406            }
2407        }
2408        Some(_) => Capability::ReadWrite,
2409        None => Capability::ReadOnly,
2410    };
2411
2412    let guard = session
2413        .db()
2414        .await
2415        .connections_for_buffer_update(project_id, session.connection_id, capability)
2416        .await?;
2417
2418    let (host, guests) = &*guard;
2419
2420    broadcast(
2421        Some(session.connection_id),
2422        guests.iter().chain([host]).copied(),
2423        |connection_id| {
2424            session
2425                .peer
2426                .forward_send(session.connection_id, connection_id, message.clone())
2427        },
2428    );
2429
2430    Ok(())
2431}
2432
2433async fn forward_project_search_chunk(
2434    message: proto::FindSearchCandidatesChunk,
2435    response: Response<proto::FindSearchCandidatesChunk>,
2436    session: MessageContext,
2437) -> Result<()> {
2438    let peer_id = message.peer_id.context("missing peer_id")?;
2439    let payload = session
2440        .peer
2441        .forward_request(session.connection_id, peer_id.into(), message)
2442        .await?;
2443    response.send(payload)?;
2444    Ok(())
2445}
2446
2447/// Notify other participants that a project has been updated.
2448async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2449    request: T,
2450    session: MessageContext,
2451) -> Result<()> {
2452    let project_id = ProjectId::from_proto(request.remote_entity_id());
2453    let project_connection_ids = session
2454        .db()
2455        .await
2456        .project_connection_ids(project_id, session.connection_id, false)
2457        .await?;
2458
2459    broadcast(
2460        Some(session.connection_id),
2461        project_connection_ids.iter().copied(),
2462        |connection_id| {
2463            session
2464                .peer
2465                .forward_send(session.connection_id, connection_id, request.clone())
2466        },
2467    );
2468    Ok(())
2469}
2470
2471/// Start following another user in a call.
2472async fn follow(
2473    request: proto::Follow,
2474    response: Response<proto::Follow>,
2475    session: MessageContext,
2476) -> Result<()> {
2477    let room_id = RoomId::from_proto(request.room_id);
2478    let project_id = request.project_id.map(ProjectId::from_proto);
2479    let leader_id = request.leader_id.context("invalid leader id")?.into();
2480    let follower_id = session.connection_id;
2481
2482    session
2483        .db()
2484        .await
2485        .check_room_participants(room_id, leader_id, session.connection_id)
2486        .await?;
2487
2488    let response_payload = session.forward_request(leader_id, request).await?;
2489    response.send(response_payload)?;
2490
2491    if let Some(project_id) = project_id {
2492        let room = session
2493            .db()
2494            .await
2495            .follow(room_id, project_id, leader_id, follower_id)
2496            .await?;
2497        room_updated(&room, &session.peer);
2498    }
2499
2500    Ok(())
2501}
2502
2503/// Stop following another user in a call.
2504async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2505    let room_id = RoomId::from_proto(request.room_id);
2506    let project_id = request.project_id.map(ProjectId::from_proto);
2507    let leader_id = request.leader_id.context("invalid leader id")?.into();
2508    let follower_id = session.connection_id;
2509
2510    session
2511        .db()
2512        .await
2513        .check_room_participants(room_id, leader_id, session.connection_id)
2514        .await?;
2515
2516    session
2517        .peer
2518        .forward_send(session.connection_id, leader_id, request)?;
2519
2520    if let Some(project_id) = project_id {
2521        let room = session
2522            .db()
2523            .await
2524            .unfollow(room_id, project_id, leader_id, follower_id)
2525            .await?;
2526        room_updated(&room, &session.peer);
2527    }
2528
2529    Ok(())
2530}
2531
2532/// Notify everyone following you of your current location.
2533async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2534    let room_id = RoomId::from_proto(request.room_id);
2535    let database = session.db.lock().await;
2536
2537    let connection_ids = if let Some(project_id) = request.project_id {
2538        let project_id = ProjectId::from_proto(project_id);
2539        database
2540            .project_connection_ids(project_id, session.connection_id, true)
2541            .await?
2542    } else {
2543        database
2544            .room_connection_ids(room_id, session.connection_id)
2545            .await?
2546    };
2547
2548    // For now, don't send view update messages back to that view's current leader.
2549    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2550        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2551        _ => None,
2552    });
2553
2554    for connection_id in connection_ids.iter().cloned() {
2555        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2556            session
2557                .peer
2558                .forward_send(session.connection_id, connection_id, request.clone())?;
2559        }
2560    }
2561    Ok(())
2562}
2563
2564/// Get public data about users.
2565async fn get_users(
2566    request: proto::GetUsers,
2567    response: Response<proto::GetUsers>,
2568    session: MessageContext,
2569) -> Result<()> {
2570    let user_ids = request
2571        .user_ids
2572        .into_iter()
2573        .map(UserId::from_proto)
2574        .collect();
2575    let users = session
2576        .db()
2577        .await
2578        .get_users_by_ids(user_ids)
2579        .await?
2580        .into_iter()
2581        .map(|user| proto::User {
2582            id: user.id.to_proto(),
2583            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2584            github_login: user.github_login,
2585            name: user.name,
2586        })
2587        .collect();
2588    response.send(proto::UsersResponse { users })?;
2589    Ok(())
2590}
2591
2592/// Search for users (to invite) buy Github login
2593async fn fuzzy_search_users(
2594    request: proto::FuzzySearchUsers,
2595    response: Response<proto::FuzzySearchUsers>,
2596    session: MessageContext,
2597) -> Result<()> {
2598    let query = request.query;
2599    let users = match query.len() {
2600        0 => vec![],
2601        1 | 2 => session
2602            .db()
2603            .await
2604            .get_user_by_github_login(&query)
2605            .await?
2606            .into_iter()
2607            .collect(),
2608        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2609    };
2610    let users = users
2611        .into_iter()
2612        .filter(|user| user.id != session.user_id())
2613        .map(|user| proto::User {
2614            id: user.id.to_proto(),
2615            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2616            github_login: user.github_login,
2617            name: user.name,
2618        })
2619        .collect();
2620    response.send(proto::UsersResponse { users })?;
2621    Ok(())
2622}
2623
2624/// Send a contact request to another user.
2625async fn request_contact(
2626    request: proto::RequestContact,
2627    response: Response<proto::RequestContact>,
2628    session: MessageContext,
2629) -> Result<()> {
2630    let requester_id = session.user_id();
2631    let responder_id = UserId::from_proto(request.responder_id);
2632    if requester_id == responder_id {
2633        return Err(anyhow!("cannot add yourself as a contact"))?;
2634    }
2635
2636    let notifications = session
2637        .db()
2638        .await
2639        .send_contact_request(requester_id, responder_id)
2640        .await?;
2641
2642    // Update outgoing contact requests of requester
2643    let mut update = proto::UpdateContacts::default();
2644    update.outgoing_requests.push(responder_id.to_proto());
2645    for connection_id in session
2646        .connection_pool()
2647        .await
2648        .user_connection_ids(requester_id)
2649    {
2650        session.peer.send(connection_id, update.clone())?;
2651    }
2652
2653    // Update incoming contact requests of responder
2654    let mut update = proto::UpdateContacts::default();
2655    update
2656        .incoming_requests
2657        .push(proto::IncomingContactRequest {
2658            requester_id: requester_id.to_proto(),
2659        });
2660    let connection_pool = session.connection_pool().await;
2661    for connection_id in connection_pool.user_connection_ids(responder_id) {
2662        session.peer.send(connection_id, update.clone())?;
2663    }
2664
2665    send_notifications(&connection_pool, &session.peer, notifications);
2666
2667    response.send(proto::Ack {})?;
2668    Ok(())
2669}
2670
2671/// Accept or decline a contact request
2672async fn respond_to_contact_request(
2673    request: proto::RespondToContactRequest,
2674    response: Response<proto::RespondToContactRequest>,
2675    session: MessageContext,
2676) -> Result<()> {
2677    let responder_id = session.user_id();
2678    let requester_id = UserId::from_proto(request.requester_id);
2679    let db = session.db().await;
2680    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2681        db.dismiss_contact_notification(responder_id, requester_id)
2682            .await?;
2683    } else {
2684        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2685
2686        let notifications = db
2687            .respond_to_contact_request(responder_id, requester_id, accept)
2688            .await?;
2689        let requester_busy = db.is_user_busy(requester_id).await?;
2690        let responder_busy = db.is_user_busy(responder_id).await?;
2691
2692        let pool = session.connection_pool().await;
2693        // Update responder with new contact
2694        let mut update = proto::UpdateContacts::default();
2695        if accept {
2696            update
2697                .contacts
2698                .push(contact_for_user(requester_id, requester_busy, &pool));
2699        }
2700        update
2701            .remove_incoming_requests
2702            .push(requester_id.to_proto());
2703        for connection_id in pool.user_connection_ids(responder_id) {
2704            session.peer.send(connection_id, update.clone())?;
2705        }
2706
2707        // Update requester with new contact
2708        let mut update = proto::UpdateContacts::default();
2709        if accept {
2710            update
2711                .contacts
2712                .push(contact_for_user(responder_id, responder_busy, &pool));
2713        }
2714        update
2715            .remove_outgoing_requests
2716            .push(responder_id.to_proto());
2717
2718        for connection_id in pool.user_connection_ids(requester_id) {
2719            session.peer.send(connection_id, update.clone())?;
2720        }
2721
2722        send_notifications(&pool, &session.peer, notifications);
2723    }
2724
2725    response.send(proto::Ack {})?;
2726    Ok(())
2727}
2728
2729/// Remove a contact.
2730async fn remove_contact(
2731    request: proto::RemoveContact,
2732    response: Response<proto::RemoveContact>,
2733    session: MessageContext,
2734) -> Result<()> {
2735    let requester_id = session.user_id();
2736    let responder_id = UserId::from_proto(request.user_id);
2737    let db = session.db().await;
2738    let (contact_accepted, deleted_notification_id) =
2739        db.remove_contact(requester_id, responder_id).await?;
2740
2741    let pool = session.connection_pool().await;
2742    // Update outgoing contact requests of requester
2743    let mut update = proto::UpdateContacts::default();
2744    if contact_accepted {
2745        update.remove_contacts.push(responder_id.to_proto());
2746    } else {
2747        update
2748            .remove_outgoing_requests
2749            .push(responder_id.to_proto());
2750    }
2751    for connection_id in pool.user_connection_ids(requester_id) {
2752        session.peer.send(connection_id, update.clone())?;
2753    }
2754
2755    // Update incoming contact requests of responder
2756    let mut update = proto::UpdateContacts::default();
2757    if contact_accepted {
2758        update.remove_contacts.push(requester_id.to_proto());
2759    } else {
2760        update
2761            .remove_incoming_requests
2762            .push(requester_id.to_proto());
2763    }
2764    for connection_id in pool.user_connection_ids(responder_id) {
2765        session.peer.send(connection_id, update.clone())?;
2766        if let Some(notification_id) = deleted_notification_id {
2767            session.peer.send(
2768                connection_id,
2769                proto::DeleteNotification {
2770                    notification_id: notification_id.to_proto(),
2771                },
2772            )?;
2773        }
2774    }
2775
2776    response.send(proto::Ack {})?;
2777    Ok(())
2778}
2779
2780fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
2781    version.0.minor < 139
2782}
2783
2784async fn subscribe_to_channels(
2785    _: proto::SubscribeToChannels,
2786    session: MessageContext,
2787) -> Result<()> {
2788    subscribe_user_to_channels(session.user_id(), &session).await?;
2789    Ok(())
2790}
2791
2792async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2793    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2794    let mut pool = session.connection_pool().await;
2795    for membership in &channels_for_user.channel_memberships {
2796        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2797    }
2798    session.peer.send(
2799        session.connection_id,
2800        build_update_user_channels(&channels_for_user),
2801    )?;
2802    session.peer.send(
2803        session.connection_id,
2804        build_channels_update(channels_for_user),
2805    )?;
2806    Ok(())
2807}
2808
2809/// Creates a new channel.
2810async fn create_channel(
2811    request: proto::CreateChannel,
2812    response: Response<proto::CreateChannel>,
2813    session: MessageContext,
2814) -> Result<()> {
2815    let db = session.db().await;
2816
2817    let parent_id = request.parent_id.map(ChannelId::from_proto);
2818    let (channel, membership) = db
2819        .create_channel(&request.name, parent_id, session.user_id())
2820        .await?;
2821
2822    let root_id = channel.root_id();
2823    let channel = Channel::from_model(channel);
2824
2825    response.send(proto::CreateChannelResponse {
2826        channel: Some(channel.to_proto()),
2827        parent_id: request.parent_id,
2828    })?;
2829
2830    let mut connection_pool = session.connection_pool().await;
2831    if let Some(membership) = membership {
2832        connection_pool.subscribe_to_channel(
2833            membership.user_id,
2834            membership.channel_id,
2835            membership.role,
2836        );
2837        let update = proto::UpdateUserChannels {
2838            channel_memberships: vec![proto::ChannelMembership {
2839                channel_id: membership.channel_id.to_proto(),
2840                role: membership.role.into(),
2841            }],
2842            ..Default::default()
2843        };
2844        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2845            session.peer.send(connection_id, update.clone())?;
2846        }
2847    }
2848
2849    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2850        if !role.can_see_channel(channel.visibility) {
2851            continue;
2852        }
2853
2854        let update = proto::UpdateChannels {
2855            channels: vec![channel.to_proto()],
2856            ..Default::default()
2857        };
2858        session.peer.send(connection_id, update.clone())?;
2859    }
2860
2861    Ok(())
2862}
2863
2864/// Delete a channel
2865async fn delete_channel(
2866    request: proto::DeleteChannel,
2867    response: Response<proto::DeleteChannel>,
2868    session: MessageContext,
2869) -> Result<()> {
2870    let db = session.db().await;
2871
2872    let channel_id = request.channel_id;
2873    let (root_channel, removed_channels) = db
2874        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2875        .await?;
2876    response.send(proto::Ack {})?;
2877
2878    // Notify members of removed channels
2879    let mut update = proto::UpdateChannels::default();
2880    update
2881        .delete_channels
2882        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2883
2884    let connection_pool = session.connection_pool().await;
2885    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2886        session.peer.send(connection_id, update.clone())?;
2887    }
2888
2889    Ok(())
2890}
2891
2892/// Invite someone to join a channel.
2893async fn invite_channel_member(
2894    request: proto::InviteChannelMember,
2895    response: Response<proto::InviteChannelMember>,
2896    session: MessageContext,
2897) -> Result<()> {
2898    let db = session.db().await;
2899    let channel_id = ChannelId::from_proto(request.channel_id);
2900    let invitee_id = UserId::from_proto(request.user_id);
2901    let InviteMemberResult {
2902        channel,
2903        notifications,
2904    } = db
2905        .invite_channel_member(
2906            channel_id,
2907            invitee_id,
2908            session.user_id(),
2909            request.role().into(),
2910        )
2911        .await?;
2912
2913    let update = proto::UpdateChannels {
2914        channel_invitations: vec![channel.to_proto()],
2915        ..Default::default()
2916    };
2917
2918    let connection_pool = session.connection_pool().await;
2919    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2920        session.peer.send(connection_id, update.clone())?;
2921    }
2922
2923    send_notifications(&connection_pool, &session.peer, notifications);
2924
2925    response.send(proto::Ack {})?;
2926    Ok(())
2927}
2928
2929/// remove someone from a channel
2930async fn remove_channel_member(
2931    request: proto::RemoveChannelMember,
2932    response: Response<proto::RemoveChannelMember>,
2933    session: MessageContext,
2934) -> Result<()> {
2935    let db = session.db().await;
2936    let channel_id = ChannelId::from_proto(request.channel_id);
2937    let member_id = UserId::from_proto(request.user_id);
2938
2939    let RemoveChannelMemberResult {
2940        membership_update,
2941        notification_id,
2942    } = db
2943        .remove_channel_member(channel_id, member_id, session.user_id())
2944        .await?;
2945
2946    let mut connection_pool = session.connection_pool().await;
2947    notify_membership_updated(
2948        &mut connection_pool,
2949        membership_update,
2950        member_id,
2951        &session.peer,
2952    );
2953    for connection_id in connection_pool.user_connection_ids(member_id) {
2954        if let Some(notification_id) = notification_id {
2955            session
2956                .peer
2957                .send(
2958                    connection_id,
2959                    proto::DeleteNotification {
2960                        notification_id: notification_id.to_proto(),
2961                    },
2962                )
2963                .trace_err();
2964        }
2965    }
2966
2967    response.send(proto::Ack {})?;
2968    Ok(())
2969}
2970
2971/// Toggle the channel between public and private.
2972/// Care is taken to maintain the invariant that public channels only descend from public channels,
2973/// (though members-only channels can appear at any point in the hierarchy).
2974async fn set_channel_visibility(
2975    request: proto::SetChannelVisibility,
2976    response: Response<proto::SetChannelVisibility>,
2977    session: MessageContext,
2978) -> Result<()> {
2979    let db = session.db().await;
2980    let channel_id = ChannelId::from_proto(request.channel_id);
2981    let visibility = request.visibility().into();
2982
2983    let channel_model = db
2984        .set_channel_visibility(channel_id, visibility, session.user_id())
2985        .await?;
2986    let root_id = channel_model.root_id();
2987    let channel = Channel::from_model(channel_model);
2988
2989    let mut connection_pool = session.connection_pool().await;
2990    for (user_id, role) in connection_pool
2991        .channel_user_ids(root_id)
2992        .collect::<Vec<_>>()
2993        .into_iter()
2994    {
2995        let update = if role.can_see_channel(channel.visibility) {
2996            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2997            proto::UpdateChannels {
2998                channels: vec![channel.to_proto()],
2999                ..Default::default()
3000            }
3001        } else {
3002            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3003            proto::UpdateChannels {
3004                delete_channels: vec![channel.id.to_proto()],
3005                ..Default::default()
3006            }
3007        };
3008
3009        for connection_id in connection_pool.user_connection_ids(user_id) {
3010            session.peer.send(connection_id, update.clone())?;
3011        }
3012    }
3013
3014    response.send(proto::Ack {})?;
3015    Ok(())
3016}
3017
3018/// Alter the role for a user in the channel.
3019async fn set_channel_member_role(
3020    request: proto::SetChannelMemberRole,
3021    response: Response<proto::SetChannelMemberRole>,
3022    session: MessageContext,
3023) -> Result<()> {
3024    let db = session.db().await;
3025    let channel_id = ChannelId::from_proto(request.channel_id);
3026    let member_id = UserId::from_proto(request.user_id);
3027    let result = db
3028        .set_channel_member_role(
3029            channel_id,
3030            session.user_id(),
3031            member_id,
3032            request.role().into(),
3033        )
3034        .await?;
3035
3036    match result {
3037        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3038            let mut connection_pool = session.connection_pool().await;
3039            notify_membership_updated(
3040                &mut connection_pool,
3041                membership_update,
3042                member_id,
3043                &session.peer,
3044            )
3045        }
3046        db::SetMemberRoleResult::InviteUpdated(channel) => {
3047            let update = proto::UpdateChannels {
3048                channel_invitations: vec![channel.to_proto()],
3049                ..Default::default()
3050            };
3051
3052            for connection_id in session
3053                .connection_pool()
3054                .await
3055                .user_connection_ids(member_id)
3056            {
3057                session.peer.send(connection_id, update.clone())?;
3058            }
3059        }
3060    }
3061
3062    response.send(proto::Ack {})?;
3063    Ok(())
3064}
3065
3066/// Change the name of a channel
3067async fn rename_channel(
3068    request: proto::RenameChannel,
3069    response: Response<proto::RenameChannel>,
3070    session: MessageContext,
3071) -> Result<()> {
3072    let db = session.db().await;
3073    let channel_id = ChannelId::from_proto(request.channel_id);
3074    let channel_model = db
3075        .rename_channel(channel_id, session.user_id(), &request.name)
3076        .await?;
3077    let root_id = channel_model.root_id();
3078    let channel = Channel::from_model(channel_model);
3079
3080    response.send(proto::RenameChannelResponse {
3081        channel: Some(channel.to_proto()),
3082    })?;
3083
3084    let connection_pool = session.connection_pool().await;
3085    let update = proto::UpdateChannels {
3086        channels: vec![channel.to_proto()],
3087        ..Default::default()
3088    };
3089    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3090        if role.can_see_channel(channel.visibility) {
3091            session.peer.send(connection_id, update.clone())?;
3092        }
3093    }
3094
3095    Ok(())
3096}
3097
3098/// Move a channel to a new parent.
3099async fn move_channel(
3100    request: proto::MoveChannel,
3101    response: Response<proto::MoveChannel>,
3102    session: MessageContext,
3103) -> Result<()> {
3104    let channel_id = ChannelId::from_proto(request.channel_id);
3105    let to = ChannelId::from_proto(request.to);
3106
3107    let (root_id, channels) = session
3108        .db()
3109        .await
3110        .move_channel(channel_id, to, session.user_id())
3111        .await?;
3112
3113    let connection_pool = session.connection_pool().await;
3114    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3115        let channels = channels
3116            .iter()
3117            .filter_map(|channel| {
3118                if role.can_see_channel(channel.visibility) {
3119                    Some(channel.to_proto())
3120                } else {
3121                    None
3122                }
3123            })
3124            .collect::<Vec<_>>();
3125        if channels.is_empty() {
3126            continue;
3127        }
3128
3129        let update = proto::UpdateChannels {
3130            channels,
3131            ..Default::default()
3132        };
3133
3134        session.peer.send(connection_id, update.clone())?;
3135    }
3136
3137    response.send(Ack {})?;
3138    Ok(())
3139}
3140
3141async fn reorder_channel(
3142    request: proto::ReorderChannel,
3143    response: Response<proto::ReorderChannel>,
3144    session: MessageContext,
3145) -> Result<()> {
3146    let channel_id = ChannelId::from_proto(request.channel_id);
3147    let direction = request.direction();
3148
3149    let updated_channels = session
3150        .db()
3151        .await
3152        .reorder_channel(channel_id, direction, session.user_id())
3153        .await?;
3154
3155    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3156        let connection_pool = session.connection_pool().await;
3157        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3158            let channels = updated_channels
3159                .iter()
3160                .filter_map(|channel| {
3161                    if role.can_see_channel(channel.visibility) {
3162                        Some(channel.to_proto())
3163                    } else {
3164                        None
3165                    }
3166                })
3167                .collect::<Vec<_>>();
3168
3169            if channels.is_empty() {
3170                continue;
3171            }
3172
3173            let update = proto::UpdateChannels {
3174                channels,
3175                ..Default::default()
3176            };
3177
3178            session.peer.send(connection_id, update.clone())?;
3179        }
3180    }
3181
3182    response.send(Ack {})?;
3183    Ok(())
3184}
3185
3186/// Get the list of channel members
3187async fn get_channel_members(
3188    request: proto::GetChannelMembers,
3189    response: Response<proto::GetChannelMembers>,
3190    session: MessageContext,
3191) -> Result<()> {
3192    let db = session.db().await;
3193    let channel_id = ChannelId::from_proto(request.channel_id);
3194    let limit = if request.limit == 0 {
3195        u16::MAX as u64
3196    } else {
3197        request.limit
3198    };
3199    let (members, users) = db
3200        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3201        .await?;
3202    response.send(proto::GetChannelMembersResponse { members, users })?;
3203    Ok(())
3204}
3205
3206/// Accept or decline a channel invitation.
3207async fn respond_to_channel_invite(
3208    request: proto::RespondToChannelInvite,
3209    response: Response<proto::RespondToChannelInvite>,
3210    session: MessageContext,
3211) -> Result<()> {
3212    let db = session.db().await;
3213    let channel_id = ChannelId::from_proto(request.channel_id);
3214    let RespondToChannelInvite {
3215        membership_update,
3216        notifications,
3217    } = db
3218        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3219        .await?;
3220
3221    let mut connection_pool = session.connection_pool().await;
3222    if let Some(membership_update) = membership_update {
3223        notify_membership_updated(
3224            &mut connection_pool,
3225            membership_update,
3226            session.user_id(),
3227            &session.peer,
3228        );
3229    } else {
3230        let update = proto::UpdateChannels {
3231            remove_channel_invitations: vec![channel_id.to_proto()],
3232            ..Default::default()
3233        };
3234
3235        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3236            session.peer.send(connection_id, update.clone())?;
3237        }
3238    };
3239
3240    send_notifications(&connection_pool, &session.peer, notifications);
3241
3242    response.send(proto::Ack {})?;
3243
3244    Ok(())
3245}
3246
3247/// Join the channels' room
3248async fn join_channel(
3249    request: proto::JoinChannel,
3250    response: Response<proto::JoinChannel>,
3251    session: MessageContext,
3252) -> Result<()> {
3253    let channel_id = ChannelId::from_proto(request.channel_id);
3254    join_channel_internal(channel_id, Box::new(response), session).await
3255}
3256
3257trait JoinChannelInternalResponse {
3258    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3259}
3260impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3261    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3262        Response::<proto::JoinChannel>::send(self, result)
3263    }
3264}
3265impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3266    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3267        Response::<proto::JoinRoom>::send(self, result)
3268    }
3269}
3270
3271async fn join_channel_internal(
3272    channel_id: ChannelId,
3273    response: Box<impl JoinChannelInternalResponse>,
3274    session: MessageContext,
3275) -> Result<()> {
3276    let joined_room = {
3277        let mut db = session.db().await;
3278        // If zed quits without leaving the room, and the user re-opens zed before the
3279        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3280        // room they were in.
3281        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3282            tracing::info!(
3283                stale_connection_id = %connection,
3284                "cleaning up stale connection",
3285            );
3286            drop(db);
3287            leave_room_for_session(&session, connection).await?;
3288            db = session.db().await;
3289        }
3290
3291        let (joined_room, membership_updated, role) = db
3292            .join_channel(channel_id, session.user_id(), session.connection_id)
3293            .await?;
3294
3295        let live_kit_connection_info =
3296            session
3297                .app_state
3298                .livekit_client
3299                .as_ref()
3300                .and_then(|live_kit| {
3301                    let (can_publish, token) = if role == ChannelRole::Guest {
3302                        (
3303                            false,
3304                            live_kit
3305                                .guest_token(
3306                                    &joined_room.room.livekit_room,
3307                                    &session.user_id().to_string(),
3308                                )
3309                                .trace_err()?,
3310                        )
3311                    } else {
3312                        (
3313                            true,
3314                            live_kit
3315                                .room_token(
3316                                    &joined_room.room.livekit_room,
3317                                    &session.user_id().to_string(),
3318                                )
3319                                .trace_err()?,
3320                        )
3321                    };
3322
3323                    Some(LiveKitConnectionInfo {
3324                        server_url: live_kit.url().into(),
3325                        token,
3326                        can_publish,
3327                    })
3328                });
3329
3330        response.send(proto::JoinRoomResponse {
3331            room: Some(joined_room.room.clone()),
3332            channel_id: joined_room
3333                .channel
3334                .as_ref()
3335                .map(|channel| channel.id.to_proto()),
3336            live_kit_connection_info,
3337        })?;
3338
3339        let mut connection_pool = session.connection_pool().await;
3340        if let Some(membership_updated) = membership_updated {
3341            notify_membership_updated(
3342                &mut connection_pool,
3343                membership_updated,
3344                session.user_id(),
3345                &session.peer,
3346            );
3347        }
3348
3349        room_updated(&joined_room.room, &session.peer);
3350
3351        joined_room
3352    };
3353
3354    channel_updated(
3355        &joined_room.channel.context("channel not returned")?,
3356        &joined_room.room,
3357        &session.peer,
3358        &*session.connection_pool().await,
3359    );
3360
3361    update_user_contacts(session.user_id(), &session).await?;
3362    Ok(())
3363}
3364
3365/// Start editing the channel notes
3366async fn join_channel_buffer(
3367    request: proto::JoinChannelBuffer,
3368    response: Response<proto::JoinChannelBuffer>,
3369    session: MessageContext,
3370) -> Result<()> {
3371    let db = session.db().await;
3372    let channel_id = ChannelId::from_proto(request.channel_id);
3373
3374    let open_response = db
3375        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3376        .await?;
3377
3378    let collaborators = open_response.collaborators.clone();
3379    response.send(open_response)?;
3380
3381    let update = UpdateChannelBufferCollaborators {
3382        channel_id: channel_id.to_proto(),
3383        collaborators: collaborators.clone(),
3384    };
3385    channel_buffer_updated(
3386        session.connection_id,
3387        collaborators
3388            .iter()
3389            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3390        &update,
3391        &session.peer,
3392    );
3393
3394    Ok(())
3395}
3396
3397/// Edit the channel notes
3398async fn update_channel_buffer(
3399    request: proto::UpdateChannelBuffer,
3400    session: MessageContext,
3401) -> Result<()> {
3402    let db = session.db().await;
3403    let channel_id = ChannelId::from_proto(request.channel_id);
3404
3405    let (collaborators, epoch, version) = db
3406        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3407        .await?;
3408
3409    channel_buffer_updated(
3410        session.connection_id,
3411        collaborators.clone(),
3412        &proto::UpdateChannelBuffer {
3413            channel_id: channel_id.to_proto(),
3414            operations: request.operations,
3415        },
3416        &session.peer,
3417    );
3418
3419    let pool = &*session.connection_pool().await;
3420
3421    let non_collaborators =
3422        pool.channel_connection_ids(channel_id)
3423            .filter_map(|(connection_id, _)| {
3424                if collaborators.contains(&connection_id) {
3425                    None
3426                } else {
3427                    Some(connection_id)
3428                }
3429            });
3430
3431    broadcast(None, non_collaborators, |peer_id| {
3432        session.peer.send(
3433            peer_id,
3434            proto::UpdateChannels {
3435                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3436                    channel_id: channel_id.to_proto(),
3437                    epoch: epoch as u64,
3438                    version: version.clone(),
3439                }],
3440                ..Default::default()
3441            },
3442        )
3443    });
3444
3445    Ok(())
3446}
3447
3448/// Rejoin the channel notes after a connection blip
3449async fn rejoin_channel_buffers(
3450    request: proto::RejoinChannelBuffers,
3451    response: Response<proto::RejoinChannelBuffers>,
3452    session: MessageContext,
3453) -> Result<()> {
3454    let db = session.db().await;
3455    let buffers = db
3456        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3457        .await?;
3458
3459    for rejoined_buffer in &buffers {
3460        let collaborators_to_notify = rejoined_buffer
3461            .buffer
3462            .collaborators
3463            .iter()
3464            .filter_map(|c| Some(c.peer_id?.into()));
3465        channel_buffer_updated(
3466            session.connection_id,
3467            collaborators_to_notify,
3468            &proto::UpdateChannelBufferCollaborators {
3469                channel_id: rejoined_buffer.buffer.channel_id,
3470                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3471            },
3472            &session.peer,
3473        );
3474    }
3475
3476    response.send(proto::RejoinChannelBuffersResponse {
3477        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3478    })?;
3479
3480    Ok(())
3481}
3482
3483/// Stop editing the channel notes
3484async fn leave_channel_buffer(
3485    request: proto::LeaveChannelBuffer,
3486    response: Response<proto::LeaveChannelBuffer>,
3487    session: MessageContext,
3488) -> Result<()> {
3489    let db = session.db().await;
3490    let channel_id = ChannelId::from_proto(request.channel_id);
3491
3492    let left_buffer = db
3493        .leave_channel_buffer(channel_id, session.connection_id)
3494        .await?;
3495
3496    response.send(Ack {})?;
3497
3498    channel_buffer_updated(
3499        session.connection_id,
3500        left_buffer.connections,
3501        &proto::UpdateChannelBufferCollaborators {
3502            channel_id: channel_id.to_proto(),
3503            collaborators: left_buffer.collaborators,
3504        },
3505        &session.peer,
3506    );
3507
3508    Ok(())
3509}
3510
3511fn channel_buffer_updated<T: EnvelopedMessage>(
3512    sender_id: ConnectionId,
3513    collaborators: impl IntoIterator<Item = ConnectionId>,
3514    message: &T,
3515    peer: &Peer,
3516) {
3517    broadcast(Some(sender_id), collaborators, |peer_id| {
3518        peer.send(peer_id, message.clone())
3519    });
3520}
3521
3522fn send_notifications(
3523    connection_pool: &ConnectionPool,
3524    peer: &Peer,
3525    notifications: db::NotificationBatch,
3526) {
3527    for (user_id, notification) in notifications {
3528        for connection_id in connection_pool.user_connection_ids(user_id) {
3529            if let Err(error) = peer.send(
3530                connection_id,
3531                proto::AddNotification {
3532                    notification: Some(notification.clone()),
3533                },
3534            ) {
3535                tracing::error!(
3536                    "failed to send notification to {:?} {}",
3537                    connection_id,
3538                    error
3539                );
3540            }
3541        }
3542    }
3543}
3544
3545/// Send a message to the channel
3546async fn send_channel_message(
3547    _request: proto::SendChannelMessage,
3548    _response: Response<proto::SendChannelMessage>,
3549    _session: MessageContext,
3550) -> Result<()> {
3551    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3552}
3553
3554/// Delete a channel message
3555async fn remove_channel_message(
3556    _request: proto::RemoveChannelMessage,
3557    _response: Response<proto::RemoveChannelMessage>,
3558    _session: MessageContext,
3559) -> Result<()> {
3560    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3561}
3562
3563async fn update_channel_message(
3564    _request: proto::UpdateChannelMessage,
3565    _response: Response<proto::UpdateChannelMessage>,
3566    _session: MessageContext,
3567) -> Result<()> {
3568    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3569}
3570
3571/// Mark a channel message as read
3572async fn acknowledge_channel_message(
3573    _request: proto::AckChannelMessage,
3574    _session: MessageContext,
3575) -> Result<()> {
3576    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3577}
3578
3579/// Mark a buffer version as synced
3580async fn acknowledge_buffer_version(
3581    request: proto::AckBufferOperation,
3582    session: MessageContext,
3583) -> Result<()> {
3584    let buffer_id = BufferId::from_proto(request.buffer_id);
3585    session
3586        .db()
3587        .await
3588        .observe_buffer_version(
3589            buffer_id,
3590            session.user_id(),
3591            request.epoch as i32,
3592            &request.version,
3593        )
3594        .await?;
3595    Ok(())
3596}
3597
3598/// Start receiving chat updates for a channel
3599async fn join_channel_chat(
3600    _request: proto::JoinChannelChat,
3601    _response: Response<proto::JoinChannelChat>,
3602    _session: MessageContext,
3603) -> Result<()> {
3604    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3605}
3606
3607/// Stop receiving chat updates for a channel
3608async fn leave_channel_chat(
3609    _request: proto::LeaveChannelChat,
3610    _session: MessageContext,
3611) -> Result<()> {
3612    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3613}
3614
3615/// Retrieve the chat history for a channel
3616async fn get_channel_messages(
3617    _request: proto::GetChannelMessages,
3618    _response: Response<proto::GetChannelMessages>,
3619    _session: MessageContext,
3620) -> Result<()> {
3621    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3622}
3623
3624/// Retrieve specific chat messages
3625async fn get_channel_messages_by_id(
3626    _request: proto::GetChannelMessagesById,
3627    _response: Response<proto::GetChannelMessagesById>,
3628    _session: MessageContext,
3629) -> Result<()> {
3630    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3631}
3632
3633/// Retrieve the current users notifications
3634async fn get_notifications(
3635    request: proto::GetNotifications,
3636    response: Response<proto::GetNotifications>,
3637    session: MessageContext,
3638) -> Result<()> {
3639    let notifications = session
3640        .db()
3641        .await
3642        .get_notifications(
3643            session.user_id(),
3644            NOTIFICATION_COUNT_PER_PAGE,
3645            request.before_id.map(db::NotificationId::from_proto),
3646        )
3647        .await?;
3648    response.send(proto::GetNotificationsResponse {
3649        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3650        notifications,
3651    })?;
3652    Ok(())
3653}
3654
3655/// Mark notifications as read
3656async fn mark_notification_as_read(
3657    request: proto::MarkNotificationRead,
3658    response: Response<proto::MarkNotificationRead>,
3659    session: MessageContext,
3660) -> Result<()> {
3661    let database = &session.db().await;
3662    let notifications = database
3663        .mark_notification_as_read_by_id(
3664            session.user_id(),
3665            NotificationId::from_proto(request.notification_id),
3666        )
3667        .await?;
3668    send_notifications(
3669        &*session.connection_pool().await,
3670        &session.peer,
3671        notifications,
3672    );
3673    response.send(proto::Ack {})?;
3674    Ok(())
3675}
3676
3677fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3678    let message = match message {
3679        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3680        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3681        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3682        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3683        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3684            code: frame.code.into(),
3685            reason: frame.reason.as_str().to_owned().into(),
3686        })),
3687        // We should never receive a frame while reading the message, according
3688        // to the `tungstenite` maintainers:
3689        //
3690        // > It cannot occur when you read messages from the WebSocket, but it
3691        // > can be used when you want to send the raw frames (e.g. you want to
3692        // > send the frames to the WebSocket without composing the full message first).
3693        // >
3694        // > — https://github.com/snapview/tungstenite-rs/issues/268
3695        TungsteniteMessage::Frame(_) => {
3696            bail!("received an unexpected frame while reading the message")
3697        }
3698    };
3699
3700    Ok(message)
3701}
3702
3703fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3704    match message {
3705        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3706        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3707        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3708        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3709        AxumMessage::Close(frame) => {
3710            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3711                code: frame.code.into(),
3712                reason: frame.reason.as_ref().into(),
3713            }))
3714        }
3715    }
3716}
3717
3718fn notify_membership_updated(
3719    connection_pool: &mut ConnectionPool,
3720    result: MembershipUpdated,
3721    user_id: UserId,
3722    peer: &Peer,
3723) {
3724    for membership in &result.new_channels.channel_memberships {
3725        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3726    }
3727    for channel_id in &result.removed_channels {
3728        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3729    }
3730
3731    let user_channels_update = proto::UpdateUserChannels {
3732        channel_memberships: result
3733            .new_channels
3734            .channel_memberships
3735            .iter()
3736            .map(|cm| proto::ChannelMembership {
3737                channel_id: cm.channel_id.to_proto(),
3738                role: cm.role.into(),
3739            })
3740            .collect(),
3741        ..Default::default()
3742    };
3743
3744    let mut update = build_channels_update(result.new_channels);
3745    update.delete_channels = result
3746        .removed_channels
3747        .into_iter()
3748        .map(|id| id.to_proto())
3749        .collect();
3750    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3751
3752    for connection_id in connection_pool.user_connection_ids(user_id) {
3753        peer.send(connection_id, user_channels_update.clone())
3754            .trace_err();
3755        peer.send(connection_id, update.clone()).trace_err();
3756    }
3757}
3758
3759fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3760    proto::UpdateUserChannels {
3761        channel_memberships: channels
3762            .channel_memberships
3763            .iter()
3764            .map(|m| proto::ChannelMembership {
3765                channel_id: m.channel_id.to_proto(),
3766                role: m.role.into(),
3767            })
3768            .collect(),
3769        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3770    }
3771}
3772
3773fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3774    let mut update = proto::UpdateChannels::default();
3775
3776    for channel in channels.channels {
3777        update.channels.push(channel.to_proto());
3778    }
3779
3780    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3781
3782    for (channel_id, participants) in channels.channel_participants {
3783        update
3784            .channel_participants
3785            .push(proto::ChannelParticipants {
3786                channel_id: channel_id.to_proto(),
3787                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3788            });
3789    }
3790
3791    for channel in channels.invited_channels {
3792        update.channel_invitations.push(channel.to_proto());
3793    }
3794
3795    update
3796}
3797
3798fn build_initial_contacts_update(
3799    contacts: Vec<db::Contact>,
3800    pool: &ConnectionPool,
3801) -> proto::UpdateContacts {
3802    let mut update = proto::UpdateContacts::default();
3803
3804    for contact in contacts {
3805        match contact {
3806            db::Contact::Accepted { user_id, busy } => {
3807                update.contacts.push(contact_for_user(user_id, busy, pool));
3808            }
3809            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3810            db::Contact::Incoming { user_id } => {
3811                update
3812                    .incoming_requests
3813                    .push(proto::IncomingContactRequest {
3814                        requester_id: user_id.to_proto(),
3815                    })
3816            }
3817        }
3818    }
3819
3820    update
3821}
3822
3823fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3824    proto::Contact {
3825        user_id: user_id.to_proto(),
3826        online: pool.is_user_online(user_id),
3827        busy,
3828    }
3829}
3830
3831fn room_updated(room: &proto::Room, peer: &Peer) {
3832    broadcast(
3833        None,
3834        room.participants
3835            .iter()
3836            .filter_map(|participant| Some(participant.peer_id?.into())),
3837        |peer_id| {
3838            peer.send(
3839                peer_id,
3840                proto::RoomUpdated {
3841                    room: Some(room.clone()),
3842                },
3843            )
3844        },
3845    );
3846}
3847
3848fn channel_updated(
3849    channel: &db::channel::Model,
3850    room: &proto::Room,
3851    peer: &Peer,
3852    pool: &ConnectionPool,
3853) {
3854    let participants = room
3855        .participants
3856        .iter()
3857        .map(|p| p.user_id)
3858        .collect::<Vec<_>>();
3859
3860    broadcast(
3861        None,
3862        pool.channel_connection_ids(channel.root_id())
3863            .filter_map(|(channel_id, role)| {
3864                role.can_see_channel(channel.visibility)
3865                    .then_some(channel_id)
3866            }),
3867        |peer_id| {
3868            peer.send(
3869                peer_id,
3870                proto::UpdateChannels {
3871                    channel_participants: vec![proto::ChannelParticipants {
3872                        channel_id: channel.id.to_proto(),
3873                        participant_user_ids: participants.clone(),
3874                    }],
3875                    ..Default::default()
3876                },
3877            )
3878        },
3879    );
3880}
3881
3882async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3883    let db = session.db().await;
3884
3885    let contacts = db.get_contacts(user_id).await?;
3886    let busy = db.is_user_busy(user_id).await?;
3887
3888    let pool = session.connection_pool().await;
3889    let updated_contact = contact_for_user(user_id, busy, &pool);
3890    for contact in contacts {
3891        if let db::Contact::Accepted {
3892            user_id: contact_user_id,
3893            ..
3894        } = contact
3895        {
3896            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3897                session
3898                    .peer
3899                    .send(
3900                        contact_conn_id,
3901                        proto::UpdateContacts {
3902                            contacts: vec![updated_contact.clone()],
3903                            remove_contacts: Default::default(),
3904                            incoming_requests: Default::default(),
3905                            remove_incoming_requests: Default::default(),
3906                            outgoing_requests: Default::default(),
3907                            remove_outgoing_requests: Default::default(),
3908                        },
3909                    )
3910                    .trace_err();
3911            }
3912        }
3913    }
3914    Ok(())
3915}
3916
3917async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
3918    let mut contacts_to_update = HashSet::default();
3919
3920    let room_id;
3921    let canceled_calls_to_user_ids;
3922    let livekit_room;
3923    let delete_livekit_room;
3924    let room;
3925    let channel;
3926
3927    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
3928        contacts_to_update.insert(session.user_id());
3929
3930        for project in left_room.left_projects.values() {
3931            project_left(project, session);
3932        }
3933
3934        room_id = RoomId::from_proto(left_room.room.id);
3935        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3936        livekit_room = mem::take(&mut left_room.room.livekit_room);
3937        delete_livekit_room = left_room.deleted;
3938        room = mem::take(&mut left_room.room);
3939        channel = mem::take(&mut left_room.channel);
3940
3941        room_updated(&room, &session.peer);
3942    } else {
3943        return Ok(());
3944    }
3945
3946    if let Some(channel) = channel {
3947        channel_updated(
3948            &channel,
3949            &room,
3950            &session.peer,
3951            &*session.connection_pool().await,
3952        );
3953    }
3954
3955    {
3956        let pool = session.connection_pool().await;
3957        for canceled_user_id in canceled_calls_to_user_ids {
3958            for connection_id in pool.user_connection_ids(canceled_user_id) {
3959                session
3960                    .peer
3961                    .send(
3962                        connection_id,
3963                        proto::CallCanceled {
3964                            room_id: room_id.to_proto(),
3965                        },
3966                    )
3967                    .trace_err();
3968            }
3969            contacts_to_update.insert(canceled_user_id);
3970        }
3971    }
3972
3973    for contact_user_id in contacts_to_update {
3974        update_user_contacts(contact_user_id, session).await?;
3975    }
3976
3977    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
3978        live_kit
3979            .remove_participant(livekit_room.clone(), session.user_id().to_string())
3980            .await
3981            .trace_err();
3982
3983        if delete_livekit_room {
3984            live_kit.delete_room(livekit_room).await.trace_err();
3985        }
3986    }
3987
3988    Ok(())
3989}
3990
3991async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3992    let left_channel_buffers = session
3993        .db()
3994        .await
3995        .leave_channel_buffers(session.connection_id)
3996        .await?;
3997
3998    for left_buffer in left_channel_buffers {
3999        channel_buffer_updated(
4000            session.connection_id,
4001            left_buffer.connections,
4002            &proto::UpdateChannelBufferCollaborators {
4003                channel_id: left_buffer.channel_id.to_proto(),
4004                collaborators: left_buffer.collaborators,
4005            },
4006            &session.peer,
4007        );
4008    }
4009
4010    Ok(())
4011}
4012
4013fn project_left(project: &db::LeftProject, session: &Session) {
4014    for connection_id in &project.connection_ids {
4015        if project.should_unshare {
4016            session
4017                .peer
4018                .send(
4019                    *connection_id,
4020                    proto::UnshareProject {
4021                        project_id: project.id.to_proto(),
4022                    },
4023                )
4024                .trace_err();
4025        } else {
4026            session
4027                .peer
4028                .send(
4029                    *connection_id,
4030                    proto::RemoveProjectCollaborator {
4031                        project_id: project.id.to_proto(),
4032                        peer_id: Some(session.connection_id.into()),
4033                    },
4034                )
4035                .trace_err();
4036        }
4037    }
4038}
4039
4040async fn share_agent_thread(
4041    request: proto::ShareAgentThread,
4042    response: Response<proto::ShareAgentThread>,
4043    session: MessageContext,
4044) -> Result<()> {
4045    let user_id = session.user_id();
4046
4047    let share_id = SharedThreadId::from_proto(request.session_id.clone())
4048        .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4049
4050    session
4051        .db()
4052        .await
4053        .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
4054        .await?;
4055
4056    response.send(proto::Ack {})?;
4057
4058    Ok(())
4059}
4060
4061async fn get_shared_agent_thread(
4062    request: proto::GetSharedAgentThread,
4063    response: Response<proto::GetSharedAgentThread>,
4064    session: MessageContext,
4065) -> Result<()> {
4066    let share_id = SharedThreadId::from_proto(request.session_id)
4067        .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4068
4069    let result = session.db().await.get_shared_thread(share_id).await?;
4070
4071    match result {
4072        Some((thread, username)) => {
4073            response.send(proto::GetSharedAgentThreadResponse {
4074                title: thread.title,
4075                thread_data: thread.data,
4076                sharer_username: username,
4077                created_at: thread.created_at.and_utc().to_rfc3339(),
4078            })?;
4079        }
4080        None => {
4081            return Err(anyhow!("Shared thread not found").into());
4082        }
4083    }
4084
4085    Ok(())
4086}
4087
4088pub trait ResultExt {
4089    type Ok;
4090
4091    fn trace_err(self) -> Option<Self::Ok>;
4092}
4093
4094impl<T, E> ResultExt for Result<T, E>
4095where
4096    E: std::fmt::Debug,
4097{
4098    type Ok = T;
4099
4100    #[track_caller]
4101    fn trace_err(self) -> Option<T> {
4102        match self {
4103            Ok(value) => Some(value),
4104            Err(error) => {
4105                tracing::error!("{:?}", error);
4106                None
4107            }
4108        }
4109    }
4110}