rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::{
   5    AppState, Error, Result, auth,
   6    db::{
   7        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
   8        InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
   9        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
  10    },
  11    executor::Executor,
  12};
  13use anyhow::{Context as _, anyhow, bail};
  14use async_tungstenite::tungstenite::{
  15    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  16};
  17use axum::headers::UserAgent;
  18use axum::{
  19    Extension, Router, TypedHeader,
  20    body::Body,
  21    extract::{
  22        ConnectInfo, WebSocketUpgrade,
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30};
  31use collections::{HashMap, HashSet};
  32pub use connection_pool::{ConnectionPool, ZedVersion};
  33use core::fmt::{self, Debug, Formatter};
  34use futures::TryFutureExt as _;
  35use rpc::proto::split_repository_update;
  36use tracing::Span;
  37use util::paths::PathStyle;
  38
  39use futures::{
  40    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  41    stream::FuturesUnordered,
  42};
  43use prometheus::{IntGauge, register_int_gauge};
  44use rpc::{
  45    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  46    proto::{
  47        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  48        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  49    },
  50};
  51use semver::Version;
  52use serde::{Serialize, Serializer};
  53use std::{
  54    any::TypeId,
  55    future::Future,
  56    marker::PhantomData,
  57    mem,
  58    net::SocketAddr,
  59    ops::{Deref, DerefMut},
  60    rc::Rc,
  61    sync::{
  62        Arc, OnceLock,
  63        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  64    },
  65    time::{Duration, Instant},
  66};
  67use tokio::sync::{Semaphore, watch};
  68use tower::ServiceBuilder;
  69use tracing::{
  70    Instrument,
  71    field::{self},
  72    info_span, instrument,
  73};
  74
  75pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  76
  77// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  78pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  79
  80const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  81const MAX_CONCURRENT_CONNECTIONS: usize = 512;
  82
  83static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
  84
  85const TOTAL_DURATION_MS: &str = "total_duration_ms";
  86const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
  87const QUEUE_DURATION_MS: &str = "queue_duration_ms";
  88const HOST_WAITING_MS: &str = "host_waiting_ms";
  89
  90type MessageHandler =
  91    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
  92
  93pub struct ConnectionGuard;
  94
  95impl ConnectionGuard {
  96    pub fn try_acquire() -> Result<Self, ()> {
  97        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
  98        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
  99            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 100            tracing::error!(
 101                "too many concurrent connections: {}",
 102                current_connections + 1
 103            );
 104            return Err(());
 105        }
 106        Ok(ConnectionGuard)
 107    }
 108}
 109
 110impl Drop for ConnectionGuard {
 111    fn drop(&mut self) {
 112        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 113    }
 114}
 115
 116struct Response<R> {
 117    peer: Arc<Peer>,
 118    receipt: Receipt<R>,
 119    responded: Arc<AtomicBool>,
 120}
 121
 122impl<R: RequestMessage> Response<R> {
 123    fn send(self, payload: R::Response) -> Result<()> {
 124        self.responded.store(true, SeqCst);
 125        self.peer.respond(self.receipt, payload)?;
 126        Ok(())
 127    }
 128}
 129
 130#[derive(Clone, Debug)]
 131pub enum Principal {
 132    User(User),
 133    Impersonated { user: User, admin: User },
 134}
 135
 136impl Principal {
 137    fn update_span(&self, span: &tracing::Span) {
 138        match &self {
 139            Principal::User(user) => {
 140                span.record("user_id", user.id.0);
 141                span.record("login", &user.github_login);
 142            }
 143            Principal::Impersonated { user, admin } => {
 144                span.record("user_id", user.id.0);
 145                span.record("login", &user.github_login);
 146                span.record("impersonator", &admin.github_login);
 147            }
 148        }
 149    }
 150}
 151
 152#[derive(Clone)]
 153struct MessageContext {
 154    session: Session,
 155    span: tracing::Span,
 156}
 157
 158impl Deref for MessageContext {
 159    type Target = Session;
 160
 161    fn deref(&self) -> &Self::Target {
 162        &self.session
 163    }
 164}
 165
 166impl MessageContext {
 167    pub fn forward_request<T: RequestMessage>(
 168        &self,
 169        receiver_id: ConnectionId,
 170        request: T,
 171    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 172        let request_start_time = Instant::now();
 173        let span = self.span.clone();
 174        tracing::info!("start forwarding request");
 175        self.peer
 176            .forward_request(self.connection_id, receiver_id, request)
 177            .inspect(move |_| {
 178                span.record(
 179                    HOST_WAITING_MS,
 180                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 181                );
 182            })
 183            .inspect_err(|_| tracing::error!("error forwarding request"))
 184            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 185    }
 186}
 187
 188#[derive(Clone)]
 189struct Session {
 190    principal: Principal,
 191    connection_id: ConnectionId,
 192    db: Arc<tokio::sync::Mutex<DbHandle>>,
 193    peer: Arc<Peer>,
 194    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 195    app_state: Arc<AppState>,
 196    /// The GeoIP country code for the user.
 197    #[allow(unused)]
 198    geoip_country_code: Option<String>,
 199    #[allow(unused)]
 200    system_id: Option<String>,
 201    _executor: Executor,
 202}
 203
 204impl Session {
 205    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 206        #[cfg(test)]
 207        tokio::task::yield_now().await;
 208        let guard = self.db.lock().await;
 209        #[cfg(test)]
 210        tokio::task::yield_now().await;
 211        guard
 212    }
 213
 214    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 215        #[cfg(test)]
 216        tokio::task::yield_now().await;
 217        let guard = self.connection_pool.lock();
 218        ConnectionPoolGuard {
 219            guard,
 220            _not_send: PhantomData,
 221        }
 222    }
 223
 224    #[expect(dead_code)]
 225    fn is_staff(&self) -> bool {
 226        match &self.principal {
 227            Principal::User(user) => user.admin,
 228            Principal::Impersonated { .. } => true,
 229        }
 230    }
 231
 232    fn user_id(&self) -> UserId {
 233        match &self.principal {
 234            Principal::User(user) => user.id,
 235            Principal::Impersonated { user, .. } => user.id,
 236        }
 237    }
 238}
 239
 240impl Debug for Session {
 241    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 242        let mut result = f.debug_struct("Session");
 243        match &self.principal {
 244            Principal::User(user) => {
 245                result.field("user", &user.github_login);
 246            }
 247            Principal::Impersonated { user, admin } => {
 248                result.field("user", &user.github_login);
 249                result.field("impersonator", &admin.github_login);
 250            }
 251        }
 252        result.field("connection_id", &self.connection_id).finish()
 253    }
 254}
 255
 256struct DbHandle(Arc<Database>);
 257
 258impl Deref for DbHandle {
 259    type Target = Database;
 260
 261    fn deref(&self) -> &Self::Target {
 262        self.0.as_ref()
 263    }
 264}
 265
 266pub struct Server {
 267    id: parking_lot::Mutex<ServerId>,
 268    peer: Arc<Peer>,
 269    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 270    app_state: Arc<AppState>,
 271    handlers: HashMap<TypeId, MessageHandler>,
 272    teardown: watch::Sender<bool>,
 273}
 274
 275pub(crate) struct ConnectionPoolGuard<'a> {
 276    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 277    _not_send: PhantomData<Rc<()>>,
 278}
 279
 280#[derive(Serialize)]
 281pub struct ServerSnapshot<'a> {
 282    peer: &'a Peer,
 283    #[serde(serialize_with = "serialize_deref")]
 284    connection_pool: ConnectionPoolGuard<'a>,
 285}
 286
 287pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 288where
 289    S: Serializer,
 290    T: Deref<Target = U>,
 291    U: Serialize,
 292{
 293    Serialize::serialize(value.deref(), serializer)
 294}
 295
 296impl Server {
 297    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 298        let mut server = Self {
 299            id: parking_lot::Mutex::new(id),
 300            peer: Peer::new(id.0 as u32),
 301            app_state,
 302            connection_pool: Default::default(),
 303            handlers: Default::default(),
 304            teardown: watch::channel(false).0,
 305        };
 306
 307        server
 308            .add_request_handler(ping)
 309            .add_request_handler(create_room)
 310            .add_request_handler(join_room)
 311            .add_request_handler(rejoin_room)
 312            .add_request_handler(leave_room)
 313            .add_request_handler(set_room_participant_role)
 314            .add_request_handler(call)
 315            .add_request_handler(cancel_call)
 316            .add_message_handler(decline_call)
 317            .add_request_handler(update_participant_location)
 318            .add_request_handler(share_project)
 319            .add_message_handler(unshare_project)
 320            .add_request_handler(join_project)
 321            .add_message_handler(leave_project)
 322            .add_request_handler(update_project)
 323            .add_request_handler(update_worktree)
 324            .add_request_handler(update_repository)
 325            .add_request_handler(remove_repository)
 326            .add_message_handler(start_language_server)
 327            .add_message_handler(update_language_server)
 328            .add_message_handler(update_diagnostic_summary)
 329            .add_message_handler(update_worktree_settings)
 330            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 331            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 332            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 333            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 334            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 335            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 336            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 337            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 338            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 339            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 340            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
 341            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 342            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
 343            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 344            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 345            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 346            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 347            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 348            .add_request_handler(
 349                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 350            )
 351            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 352            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 353            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 354            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 355            .add_request_handler(
 356                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 357            )
 358            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 359            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 360            .add_request_handler(
 361                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 362            )
 363            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 364            .add_request_handler(
 365                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 366            )
 367            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 368            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 369            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 370            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 371            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 372            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 373            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 374            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 375            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 376            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 377            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 378            .add_request_handler(
 379                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 380            )
 381            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 382            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 383            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 384            .add_request_handler(lsp_query)
 385            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
 386            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 387            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 388            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 389            .add_message_handler(create_buffer_for_peer)
 390            .add_message_handler(create_image_for_peer)
 391            .add_request_handler(update_buffer)
 392            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 393            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 394            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 395            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 396            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 397            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 398            .add_message_handler(
 399                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 400            )
 401            .add_request_handler(get_users)
 402            .add_request_handler(fuzzy_search_users)
 403            .add_request_handler(request_contact)
 404            .add_request_handler(remove_contact)
 405            .add_request_handler(respond_to_contact_request)
 406            .add_message_handler(subscribe_to_channels)
 407            .add_request_handler(create_channel)
 408            .add_request_handler(delete_channel)
 409            .add_request_handler(invite_channel_member)
 410            .add_request_handler(remove_channel_member)
 411            .add_request_handler(set_channel_member_role)
 412            .add_request_handler(set_channel_visibility)
 413            .add_request_handler(rename_channel)
 414            .add_request_handler(join_channel_buffer)
 415            .add_request_handler(leave_channel_buffer)
 416            .add_message_handler(update_channel_buffer)
 417            .add_request_handler(rejoin_channel_buffers)
 418            .add_request_handler(get_channel_members)
 419            .add_request_handler(respond_to_channel_invite)
 420            .add_request_handler(join_channel)
 421            .add_request_handler(join_channel_chat)
 422            .add_message_handler(leave_channel_chat)
 423            .add_request_handler(send_channel_message)
 424            .add_request_handler(remove_channel_message)
 425            .add_request_handler(update_channel_message)
 426            .add_request_handler(get_channel_messages)
 427            .add_request_handler(get_channel_messages_by_id)
 428            .add_request_handler(get_notifications)
 429            .add_request_handler(mark_notification_as_read)
 430            .add_request_handler(move_channel)
 431            .add_request_handler(reorder_channel)
 432            .add_request_handler(follow)
 433            .add_message_handler(unfollow)
 434            .add_message_handler(update_followers)
 435            .add_message_handler(acknowledge_channel_message)
 436            .add_message_handler(acknowledge_buffer_version)
 437            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 438            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 439            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 440            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 441            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 442            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 443            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 444            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
 445            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 446            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
 447            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 448            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 449            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 450            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 451            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 452            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 453            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 454            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 455            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 456            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 457            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 458            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
 459            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
 460            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 461            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 462            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
 463            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
 464            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 465            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 466            .add_message_handler(update_context)
 467            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
 468            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>);
 469
 470        Arc::new(server)
 471    }
 472
 473    pub async fn start(&self) -> Result<()> {
 474        let server_id = *self.id.lock();
 475        let app_state = self.app_state.clone();
 476        let peer = self.peer.clone();
 477        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 478        let pool = self.connection_pool.clone();
 479        let livekit_client = self.app_state.livekit_client.clone();
 480
 481        let span = info_span!("start server");
 482        self.app_state.executor.spawn_detached(
 483            async move {
 484                tracing::info!("waiting for cleanup timeout");
 485                timeout.await;
 486                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 487
 488                app_state
 489                    .db
 490                    .delete_stale_channel_chat_participants(
 491                        &app_state.config.zed_environment,
 492                        server_id,
 493                    )
 494                    .await
 495                    .trace_err();
 496
 497                if let Some((room_ids, channel_ids)) = app_state
 498                    .db
 499                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 500                    .await
 501                    .trace_err()
 502                {
 503                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 504                    tracing::info!(
 505                        stale_channel_buffer_count = channel_ids.len(),
 506                        "retrieved stale channel buffers"
 507                    );
 508
 509                    for channel_id in channel_ids {
 510                        if let Some(refreshed_channel_buffer) = app_state
 511                            .db
 512                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 513                            .await
 514                            .trace_err()
 515                        {
 516                            for connection_id in refreshed_channel_buffer.connection_ids {
 517                                peer.send(
 518                                    connection_id,
 519                                    proto::UpdateChannelBufferCollaborators {
 520                                        channel_id: channel_id.to_proto(),
 521                                        collaborators: refreshed_channel_buffer
 522                                            .collaborators
 523                                            .clone(),
 524                                    },
 525                                )
 526                                .trace_err();
 527                            }
 528                        }
 529                    }
 530
 531                    for room_id in room_ids {
 532                        let mut contacts_to_update = HashSet::default();
 533                        let mut canceled_calls_to_user_ids = Vec::new();
 534                        let mut livekit_room = String::new();
 535                        let mut delete_livekit_room = false;
 536
 537                        if let Some(mut refreshed_room) = app_state
 538                            .db
 539                            .clear_stale_room_participants(room_id, server_id)
 540                            .await
 541                            .trace_err()
 542                        {
 543                            tracing::info!(
 544                                room_id = room_id.0,
 545                                new_participant_count = refreshed_room.room.participants.len(),
 546                                "refreshed room"
 547                            );
 548                            room_updated(&refreshed_room.room, &peer);
 549                            if let Some(channel) = refreshed_room.channel.as_ref() {
 550                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 551                            }
 552                            contacts_to_update
 553                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 554                            contacts_to_update
 555                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 556                            canceled_calls_to_user_ids =
 557                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 558                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
 559                            delete_livekit_room = refreshed_room.room.participants.is_empty();
 560                        }
 561
 562                        {
 563                            let pool = pool.lock();
 564                            for canceled_user_id in canceled_calls_to_user_ids {
 565                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 566                                    peer.send(
 567                                        connection_id,
 568                                        proto::CallCanceled {
 569                                            room_id: room_id.to_proto(),
 570                                        },
 571                                    )
 572                                    .trace_err();
 573                                }
 574                            }
 575                        }
 576
 577                        for user_id in contacts_to_update {
 578                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 579                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 580                            if let Some((busy, contacts)) = busy.zip(contacts) {
 581                                let pool = pool.lock();
 582                                let updated_contact = contact_for_user(user_id, busy, &pool);
 583                                for contact in contacts {
 584                                    if let db::Contact::Accepted {
 585                                        user_id: contact_user_id,
 586                                        ..
 587                                    } = contact
 588                                    {
 589                                        for contact_conn_id in
 590                                            pool.user_connection_ids(contact_user_id)
 591                                        {
 592                                            peer.send(
 593                                                contact_conn_id,
 594                                                proto::UpdateContacts {
 595                                                    contacts: vec![updated_contact.clone()],
 596                                                    remove_contacts: Default::default(),
 597                                                    incoming_requests: Default::default(),
 598                                                    remove_incoming_requests: Default::default(),
 599                                                    outgoing_requests: Default::default(),
 600                                                    remove_outgoing_requests: Default::default(),
 601                                                },
 602                                            )
 603                                            .trace_err();
 604                                        }
 605                                    }
 606                                }
 607                            }
 608                        }
 609
 610                        if let Some(live_kit) = livekit_client.as_ref()
 611                            && delete_livekit_room
 612                        {
 613                            live_kit.delete_room(livekit_room).await.trace_err();
 614                        }
 615                    }
 616                }
 617
 618                app_state
 619                    .db
 620                    .delete_stale_channel_chat_participants(
 621                        &app_state.config.zed_environment,
 622                        server_id,
 623                    )
 624                    .await
 625                    .trace_err();
 626
 627                app_state
 628                    .db
 629                    .clear_old_worktree_entries(server_id)
 630                    .await
 631                    .trace_err();
 632
 633                app_state
 634                    .db
 635                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 636                    .await
 637                    .trace_err();
 638            }
 639            .instrument(span),
 640        );
 641        Ok(())
 642    }
 643
 644    pub fn teardown(&self) {
 645        self.peer.teardown();
 646        self.connection_pool.lock().reset();
 647        let _ = self.teardown.send(true);
 648    }
 649
 650    #[cfg(test)]
 651    pub fn reset(&self, id: ServerId) {
 652        self.teardown();
 653        *self.id.lock() = id;
 654        self.peer.reset(id.0 as u32);
 655        let _ = self.teardown.send(false);
 656    }
 657
 658    #[cfg(test)]
 659    pub fn id(&self) -> ServerId {
 660        *self.id.lock()
 661    }
 662
 663    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 664    where
 665        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
 666        Fut: 'static + Send + Future<Output = Result<()>>,
 667        M: EnvelopedMessage,
 668    {
 669        let prev_handler = self.handlers.insert(
 670            TypeId::of::<M>(),
 671            Box::new(move |envelope, session, span| {
 672                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 673                let received_at = envelope.received_at;
 674                tracing::info!("message received");
 675                let start_time = Instant::now();
 676                let future = (handler)(
 677                    *envelope,
 678                    MessageContext {
 679                        session,
 680                        span: span.clone(),
 681                    },
 682                );
 683                async move {
 684                    let result = future.await;
 685                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 686                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 687                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 688                    span.record(TOTAL_DURATION_MS, total_duration_ms);
 689                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
 690                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
 691                    match result {
 692                        Err(error) => {
 693                            tracing::error!(?error, "error handling message")
 694                        }
 695                        Ok(()) => tracing::info!("finished handling message"),
 696                    }
 697                }
 698                .boxed()
 699            }),
 700        );
 701        if prev_handler.is_some() {
 702            panic!("registered a handler for the same message twice");
 703        }
 704        self
 705    }
 706
 707    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 708    where
 709        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 710        Fut: 'static + Send + Future<Output = Result<()>>,
 711        M: EnvelopedMessage,
 712    {
 713        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 714        self
 715    }
 716
 717    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 718    where
 719        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 720        Fut: Send + Future<Output = Result<()>>,
 721        M: RequestMessage,
 722    {
 723        let handler = Arc::new(handler);
 724        self.add_handler(move |envelope, session| {
 725            let receipt = envelope.receipt();
 726            let handler = handler.clone();
 727            async move {
 728                let peer = session.peer.clone();
 729                let responded = Arc::new(AtomicBool::default());
 730                let response = Response {
 731                    peer: peer.clone(),
 732                    responded: responded.clone(),
 733                    receipt,
 734                };
 735                match (handler)(envelope.payload, response, session).await {
 736                    Ok(()) => {
 737                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 738                            Ok(())
 739                        } else {
 740                            Err(anyhow!("handler did not send a response"))?
 741                        }
 742                    }
 743                    Err(error) => {
 744                        let proto_err = match &error {
 745                            Error::Internal(err) => err.to_proto(),
 746                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 747                        };
 748                        peer.respond_with_error(receipt, proto_err)?;
 749                        Err(error)
 750                    }
 751                }
 752            }
 753        })
 754    }
 755
 756    pub fn handle_connection(
 757        self: &Arc<Self>,
 758        connection: Connection,
 759        address: String,
 760        principal: Principal,
 761        zed_version: ZedVersion,
 762        release_channel: Option<String>,
 763        user_agent: Option<String>,
 764        geoip_country_code: Option<String>,
 765        system_id: Option<String>,
 766        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 767        executor: Executor,
 768        connection_guard: Option<ConnectionGuard>,
 769    ) -> impl Future<Output = ()> + use<> {
 770        let this = self.clone();
 771        let span = info_span!("handle connection", %address,
 772            connection_id=field::Empty,
 773            user_id=field::Empty,
 774            login=field::Empty,
 775            impersonator=field::Empty,
 776            user_agent=field::Empty,
 777            geoip_country_code=field::Empty,
 778            release_channel=field::Empty,
 779        );
 780        principal.update_span(&span);
 781        if let Some(user_agent) = user_agent {
 782            span.record("user_agent", user_agent);
 783        }
 784        if let Some(release_channel) = release_channel {
 785            span.record("release_channel", release_channel);
 786        }
 787
 788        if let Some(country_code) = geoip_country_code.as_ref() {
 789            span.record("geoip_country_code", country_code);
 790        }
 791
 792        let mut teardown = self.teardown.subscribe();
 793        async move {
 794            if *teardown.borrow() {
 795                tracing::error!("server is tearing down");
 796                return;
 797            }
 798
 799            let (connection_id, handle_io, mut incoming_rx) =
 800                this.peer.add_connection(connection, {
 801                    let executor = executor.clone();
 802                    move |duration| executor.sleep(duration)
 803                });
 804            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 805
 806            tracing::info!("connection opened");
 807
 808            let session = Session {
 809                principal: principal.clone(),
 810                connection_id,
 811                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 812                peer: this.peer.clone(),
 813                connection_pool: this.connection_pool.clone(),
 814                app_state: this.app_state.clone(),
 815                geoip_country_code,
 816                system_id,
 817                _executor: executor.clone(),
 818            };
 819
 820            if let Err(error) = this
 821                .send_initial_client_update(
 822                    connection_id,
 823                    zed_version,
 824                    send_connection_id,
 825                    &session,
 826                )
 827                .await
 828            {
 829                tracing::error!(?error, "failed to send initial client update");
 830                return;
 831            }
 832            drop(connection_guard);
 833
 834            let handle_io = handle_io.fuse();
 835            futures::pin_mut!(handle_io);
 836
 837            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 838            // This prevents deadlocks when e.g., client A performs a request to client B and
 839            // client B performs a request to client A. If both clients stop processing further
 840            // messages until their respective request completes, they won't have a chance to
 841            // respond to the other client's request and cause a deadlock.
 842            //
 843            // This arrangement ensures we will attempt to process earlier messages first, but fall
 844            // back to processing messages arrived later in the spirit of making progress.
 845            const MAX_CONCURRENT_HANDLERS: usize = 256;
 846            let mut foreground_message_handlers = FuturesUnordered::new();
 847            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
 848            let get_concurrent_handlers = {
 849                let concurrent_handlers = concurrent_handlers.clone();
 850                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
 851            };
 852            loop {
 853                let next_message = async {
 854                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 855                    let message = incoming_rx.next().await;
 856                    // Cache the concurrent_handlers here, so that we know what the
 857                    // queue looks like as each handler starts
 858                    (permit, message, get_concurrent_handlers())
 859                }
 860                .fuse();
 861                futures::pin_mut!(next_message);
 862                futures::select_biased! {
 863                    _ = teardown.changed().fuse() => return,
 864                    result = handle_io => {
 865                        if let Err(error) = result {
 866                            tracing::error!(?error, "error handling I/O");
 867                        }
 868                        break;
 869                    }
 870                    _ = foreground_message_handlers.next() => {}
 871                    next_message = next_message => {
 872                        let (permit, message, concurrent_handlers) = next_message;
 873                        if let Some(message) = message {
 874                            let type_name = message.payload_type_name();
 875                            // note: we copy all the fields from the parent span so we can query them in the logs.
 876                            // (https://github.com/tokio-rs/tracing/issues/2670).
 877                            let span = tracing::info_span!("receive message",
 878                                %connection_id,
 879                                %address,
 880                                type_name,
 881                                concurrent_handlers,
 882                                user_id=field::Empty,
 883                                login=field::Empty,
 884                                impersonator=field::Empty,
 885                                lsp_query_request=field::Empty,
 886                                release_channel=field::Empty,
 887                                { TOTAL_DURATION_MS }=field::Empty,
 888                                { PROCESSING_DURATION_MS }=field::Empty,
 889                                { QUEUE_DURATION_MS }=field::Empty,
 890                                { HOST_WAITING_MS }=field::Empty
 891                            );
 892                            principal.update_span(&span);
 893                            let span_enter = span.enter();
 894                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 895                                let is_background = message.is_background();
 896                                let handle_message = (handler)(message, session.clone(), span.clone());
 897                                drop(span_enter);
 898
 899                                let handle_message = async move {
 900                                    handle_message.await;
 901                                    drop(permit);
 902                                }.instrument(span);
 903                                if is_background {
 904                                    executor.spawn_detached(handle_message);
 905                                } else {
 906                                    foreground_message_handlers.push(handle_message);
 907                                }
 908                            } else {
 909                                tracing::error!("no message handler");
 910                            }
 911                        } else {
 912                            tracing::info!("connection closed");
 913                            break;
 914                        }
 915                    }
 916                }
 917            }
 918
 919            drop(foreground_message_handlers);
 920            let concurrent_handlers = get_concurrent_handlers();
 921            tracing::info!(concurrent_handlers, "signing out");
 922            if let Err(error) = connection_lost(session, teardown, executor).await {
 923                tracing::error!(?error, "error signing out");
 924            }
 925        }
 926        .instrument(span)
 927    }
 928
 929    async fn send_initial_client_update(
 930        &self,
 931        connection_id: ConnectionId,
 932        zed_version: ZedVersion,
 933        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 934        session: &Session,
 935    ) -> Result<()> {
 936        self.peer.send(
 937            connection_id,
 938            proto::Hello {
 939                peer_id: Some(connection_id.into()),
 940            },
 941        )?;
 942        tracing::info!("sent hello message");
 943        if let Some(send_connection_id) = send_connection_id.take() {
 944            let _ = send_connection_id.send(connection_id);
 945        }
 946
 947        match &session.principal {
 948            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 949                if !user.connected_once {
 950                    self.peer.send(connection_id, proto::ShowContacts {})?;
 951                    self.app_state
 952                        .db
 953                        .set_user_connected_once(user.id, true)
 954                        .await?;
 955                }
 956
 957                let contacts = self.app_state.db.get_contacts(user.id).await?;
 958
 959                {
 960                    let mut pool = self.connection_pool.lock();
 961                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
 962                    self.peer.send(
 963                        connection_id,
 964                        build_initial_contacts_update(contacts, &pool),
 965                    )?;
 966                }
 967
 968                if should_auto_subscribe_to_channels(&zed_version) {
 969                    subscribe_user_to_channels(user.id, session).await?;
 970                }
 971
 972                if let Some(incoming_call) =
 973                    self.app_state.db.incoming_call_for_user(user.id).await?
 974                {
 975                    self.peer.send(connection_id, incoming_call)?;
 976                }
 977
 978                update_user_contacts(user.id, session).await?;
 979            }
 980        }
 981
 982        Ok(())
 983    }
 984
 985    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
 986        ServerSnapshot {
 987            connection_pool: ConnectionPoolGuard {
 988                guard: self.connection_pool.lock(),
 989                _not_send: PhantomData,
 990            },
 991            peer: &self.peer,
 992        }
 993    }
 994}
 995
 996impl Deref for ConnectionPoolGuard<'_> {
 997    type Target = ConnectionPool;
 998
 999    fn deref(&self) -> &Self::Target {
1000        &self.guard
1001    }
1002}
1003
1004impl DerefMut for ConnectionPoolGuard<'_> {
1005    fn deref_mut(&mut self) -> &mut Self::Target {
1006        &mut self.guard
1007    }
1008}
1009
/// In test builds, runs `check_invariants` on the pool each time a guard is dropped,
/// i.e. right after every locked section ends. This presumably panics on violation;
/// confirm in `ConnectionPool`. In non-test builds the drop does nothing extra.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1016
1017fn broadcast<F>(
1018    sender_id: Option<ConnectionId>,
1019    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1020    mut f: F,
1021) where
1022    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1023{
1024    for receiver_id in receiver_ids {
1025        if Some(receiver_id) != sender_id
1026            && let Err(error) = f(receiver_id)
1027        {
1028            tracing::error!("failed to send to {:?} {}", receiver_id, error);
1029        }
1030    }
1031}
1032
1033pub struct ProtocolVersion(u32);
1034
1035impl Header for ProtocolVersion {
1036    fn name() -> &'static HeaderName {
1037        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1038        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1039    }
1040
1041    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1042    where
1043        Self: Sized,
1044        I: Iterator<Item = &'i axum::http::HeaderValue>,
1045    {
1046        let version = values
1047            .next()
1048            .ok_or_else(axum::headers::Error::invalid)?
1049            .to_str()
1050            .map_err(|_| axum::headers::Error::invalid())?
1051            .parse()
1052            .map_err(|_| axum::headers::Error::invalid())?;
1053        Ok(Self(version))
1054    }
1055
1056    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1057        values.extend([self.0.to_string().parse().unwrap()]);
1058    }
1059}
1060
1061pub struct AppVersionHeader(Version);
1062impl Header for AppVersionHeader {
1063    fn name() -> &'static HeaderName {
1064        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1065        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1066    }
1067
1068    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1069    where
1070        Self: Sized,
1071        I: Iterator<Item = &'i axum::http::HeaderValue>,
1072    {
1073        let version = values
1074            .next()
1075            .ok_or_else(axum::headers::Error::invalid)?
1076            .to_str()
1077            .map_err(|_| axum::headers::Error::invalid())?
1078            .parse()
1079            .map_err(|_| axum::headers::Error::invalid())?;
1080        Ok(Self(version))
1081    }
1082
1083    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1084        values.extend([self.0.to_string().parse().unwrap()]);
1085    }
1086}
1087
1088#[derive(Debug)]
1089pub struct ReleaseChannelHeader(String);
1090
1091impl Header for ReleaseChannelHeader {
1092    fn name() -> &'static HeaderName {
1093        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1094        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1095    }
1096
1097    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1098    where
1099        Self: Sized,
1100        I: Iterator<Item = &'i axum::http::HeaderValue>,
1101    {
1102        Ok(Self(
1103            values
1104                .next()
1105                .ok_or_else(axum::headers::Error::invalid)?
1106                .to_str()
1107                .map_err(|_| axum::headers::Error::invalid())?
1108                .to_owned(),
1109        ))
1110    }
1111
1112    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1113        values.extend([self.0.parse().unwrap()]);
1114    }
1115}
1116
1117pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1118    Router::new()
1119        .route("/rpc", get(handle_websocket_request))
1120        .layer(
1121            ServiceBuilder::new()
1122                .layer(Extension(server.app_state.clone()))
1123                .layer(middleware::from_fn(auth::validate_header)),
1124        )
1125        .route("/metrics", get(handle_metrics))
1126        .layer(Extension(server))
1127}
1128
1129pub async fn handle_websocket_request(
1130    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1131    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1132    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1133    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1134    Extension(server): Extension<Arc<Server>>,
1135    Extension(principal): Extension<Principal>,
1136    user_agent: Option<TypedHeader<UserAgent>>,
1137    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1138    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1139    ws: WebSocketUpgrade,
1140) -> axum::response::Response {
1141    if protocol_version != rpc::PROTOCOL_VERSION {
1142        return (
1143            StatusCode::UPGRADE_REQUIRED,
1144            "client must be upgraded".to_string(),
1145        )
1146            .into_response();
1147    }
1148
1149    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1150        return (
1151            StatusCode::UPGRADE_REQUIRED,
1152            "no version header found".to_string(),
1153        )
1154            .into_response();
1155    };
1156
1157    let release_channel = release_channel_header.map(|header| header.0.0);
1158
1159    if !version.can_collaborate() {
1160        return (
1161            StatusCode::UPGRADE_REQUIRED,
1162            "client must be upgraded".to_string(),
1163        )
1164            .into_response();
1165    }
1166
1167    let socket_address = socket_address.to_string();
1168
1169    // Acquire connection guard before WebSocket upgrade
1170    let connection_guard = match ConnectionGuard::try_acquire() {
1171        Ok(guard) => guard,
1172        Err(()) => {
1173            return (
1174                StatusCode::SERVICE_UNAVAILABLE,
1175                "Too many concurrent connections",
1176            )
1177                .into_response();
1178        }
1179    };
1180
1181    ws.on_upgrade(move |socket| {
1182        let socket = socket
1183            .map_ok(to_tungstenite_message)
1184            .err_into()
1185            .with(|message| async move { to_axum_message(message) });
1186        let connection = Connection::new(Box::pin(socket));
1187        async move {
1188            server
1189                .handle_connection(
1190                    connection,
1191                    socket_address,
1192                    principal,
1193                    version,
1194                    release_channel,
1195                    user_agent.map(|header| header.to_string()),
1196                    country_code_header.map(|header| header.to_string()),
1197                    system_id_header.map(|header| header.to_string()),
1198                    None,
1199                    Executor::Production,
1200                    Some(connection_guard),
1201                )
1202                .await;
1203        }
1204    })
1205}
1206
1207pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1208    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1209    let connections_metric = CONNECTIONS_METRIC
1210        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1211
1212    let connections = server
1213        .connection_pool
1214        .lock()
1215        .connections()
1216        .filter(|connection| !connection.admin)
1217        .count();
1218    connections_metric.set(connections as _);
1219
1220    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1221    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1222        register_int_gauge!(
1223            "shared_projects",
1224            "number of open projects with one or more guests"
1225        )
1226        .unwrap()
1227    });
1228
1229    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1230    shared_projects_metric.set(shared_projects as _);
1231
1232    let encoder = prometheus::TextEncoder::new();
1233    let metric_families = prometheus::gather();
1234    let encoded_metrics = encoder
1235        .encode_to_string(&metric_families)
1236        .map_err(|err| anyhow!("{err}"))?;
1237    Ok(encoded_metrics)
1238}
1239
/// Cleans up after a connection's message loop has ended.
///
/// It immediately disconnects the connection from the peer, removes it from the
/// connection pool, and reports the lost connection to the database.
///
/// It then waits on two events:
/// - `RECONNECT_TIMEOUT` elapsing (presumably a grace period for the client to
///   reconnect; confirm);
/// - a change on the `teardown` watch channel.
///
/// If the timeout fires first, the session leaves its room and channel buffers. If the
/// user has no remaining online connection, any pending call is declined. Finally the
/// user's contacts are updated. If teardown fires first, that delayed cleanup is skipped.
///
/// # Errors
///
/// Returns an error if `remove_connection` or `update_user_contacts` fails. Failures
/// of the other steps are only logged via `trace_err`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best effort: a database failure here is logged, not propagated.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls in order, so the timeout branch wins if both are ready.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call when no other connection for this user remains.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1286
1287/// Acknowledges a ping from a client, used to keep the connection alive.
1288async fn ping(
1289    _: proto::Ping,
1290    response: Response<proto::Ping>,
1291    _session: MessageContext,
1292) -> Result<()> {
1293    response.send(proto::Ack {})?;
1294    Ok(())
1295}
1296
1297/// Creates a new room for calling (outside of channels)
1298async fn create_room(
1299    _request: proto::CreateRoom,
1300    response: Response<proto::CreateRoom>,
1301    session: MessageContext,
1302) -> Result<()> {
1303    let livekit_room = nanoid::nanoid!(30);
1304
1305    let live_kit_connection_info = util::maybe!(async {
1306        let live_kit = session.app_state.livekit_client.as_ref();
1307        let live_kit = live_kit?;
1308        let user_id = session.user_id().to_string();
1309
1310        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1311
1312        Some(proto::LiveKitConnectionInfo {
1313            server_url: live_kit.url().into(),
1314            token,
1315            can_publish: true,
1316        })
1317    })
1318    .await;
1319
1320    let room = session
1321        .db()
1322        .await
1323        .create_room(session.user_id(), session.connection_id, &livekit_room)
1324        .await?;
1325
1326    response.send(proto::CreateRoomResponse {
1327        room: Some(room.clone()),
1328        live_kit_connection_info,
1329    })?;
1330
1331    update_user_contacts(session.user_id(), &session).await?;
1332    Ok(())
1333}
1334
1335/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1336async fn join_room(
1337    request: proto::JoinRoom,
1338    response: Response<proto::JoinRoom>,
1339    session: MessageContext,
1340) -> Result<()> {
1341    let room_id = RoomId::from_proto(request.id);
1342
1343    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1344
1345    if let Some(channel_id) = channel_id {
1346        return join_channel_internal(channel_id, Box::new(response), session).await;
1347    }
1348
1349    let joined_room = {
1350        let room = session
1351            .db()
1352            .await
1353            .join_room(room_id, session.user_id(), session.connection_id)
1354            .await?;
1355        room_updated(&room.room, &session.peer);
1356        room.into_inner()
1357    };
1358
1359    for connection_id in session
1360        .connection_pool()
1361        .await
1362        .user_connection_ids(session.user_id())
1363    {
1364        session
1365            .peer
1366            .send(
1367                connection_id,
1368                proto::CallCanceled {
1369                    room_id: room_id.to_proto(),
1370                },
1371            )
1372            .trace_err();
1373    }
1374
1375    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1376    {
1377        live_kit
1378            .room_token(
1379                &joined_room.room.livekit_room,
1380                &session.user_id().to_string(),
1381            )
1382            .trace_err()
1383            .map(|token| proto::LiveKitConnectionInfo {
1384                server_url: live_kit.url().into(),
1385                token,
1386                can_publish: true,
1387            })
1388    } else {
1389        None
1390    };
1391
1392    response.send(proto::JoinRoomResponse {
1393        room: Some(joined_room.room),
1394        channel_id: None,
1395        live_kit_connection_info,
1396    })?;
1397
1398    update_user_contacts(session.user_id(), &session).await?;
1399    Ok(())
1400}
1401
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// It performs the rejoin in the database, then replies with the room, the reshared
/// projects (with their collaborators), and the rejoined projects, and broadcasts the
/// room update.
///
/// For each reshared project:
/// - every collaborator is sent `UpdateProjectCollaborator`, mapping the old connection
///   to the caller's new peer id;
/// - every collaborator other than the caller is forwarded the project's worktrees.
///
/// Rejoined projects are handled by `notify_rejoined_projects`. Afterwards, the room's
/// channel (if any) is broadcast as updated and the user's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    // `rejoined_room` is consumed via `into_inner()` at the end of this block, before
    // the channel and contact updates below run.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Point each collaborator at the caller's new peer id.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktrees to everyone except the caller.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1493
1494fn notify_rejoined_projects(
1495    rejoined_projects: &mut Vec<RejoinedProject>,
1496    session: &Session,
1497) -> Result<()> {
1498    for project in rejoined_projects.iter() {
1499        for collaborator in &project.collaborators {
1500            session
1501                .peer
1502                .send(
1503                    collaborator.connection_id,
1504                    proto::UpdateProjectCollaborator {
1505                        project_id: project.id.to_proto(),
1506                        old_peer_id: Some(project.old_connection_id.into()),
1507                        new_peer_id: Some(session.connection_id.into()),
1508                    },
1509                )
1510                .trace_err();
1511        }
1512    }
1513
1514    for project in rejoined_projects {
1515        for worktree in mem::take(&mut project.worktrees) {
1516            // Stream this worktree's entries.
1517            let message = proto::UpdateWorktree {
1518                project_id: project.id.to_proto(),
1519                worktree_id: worktree.id,
1520                abs_path: worktree.abs_path.clone(),
1521                root_name: worktree.root_name,
1522                updated_entries: worktree.updated_entries,
1523                removed_entries: worktree.removed_entries,
1524                scan_id: worktree.scan_id,
1525                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1526                updated_repositories: worktree.updated_repositories,
1527                removed_repositories: worktree.removed_repositories,
1528            };
1529            for update in proto::split_worktree_update(message) {
1530                session.peer.send(session.connection_id, update)?;
1531            }
1532
1533            // Stream this worktree's diagnostics.
1534            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1535            if let Some(summary) = worktree_diagnostics.next() {
1536                let message = proto::UpdateDiagnosticSummary {
1537                    project_id: project.id.to_proto(),
1538                    worktree_id: worktree.id,
1539                    summary: Some(summary),
1540                    more_summaries: worktree_diagnostics.collect(),
1541                };
1542                session.peer.send(session.connection_id, message)?;
1543            }
1544
1545            for settings_file in worktree.settings_files {
1546                session.peer.send(
1547                    session.connection_id,
1548                    proto::UpdateWorktreeSettings {
1549                        project_id: project.id.to_proto(),
1550                        worktree_id: worktree.id,
1551                        path: settings_file.path,
1552                        content: Some(settings_file.content),
1553                        kind: Some(settings_file.kind.to_proto().into()),
1554                    },
1555                )?;
1556            }
1557        }
1558
1559        for repository in mem::take(&mut project.updated_repositories) {
1560            for update in split_repository_update(repository) {
1561                session.peer.send(session.connection_id, update)?;
1562            }
1563        }
1564
1565        for id in mem::take(&mut project.removed_repositories) {
1566            session.peer.send(
1567                session.connection_id,
1568                proto::RemoveRepository {
1569                    project_id: project.id.to_proto(),
1570                    id,
1571                },
1572            )?;
1573        }
1574    }
1575
1576    Ok(())
1577}
1578
1579/// leave room disconnects from the room.
1580async fn leave_room(
1581    _: proto::LeaveRoom,
1582    response: Response<proto::LeaveRoom>,
1583    session: MessageContext,
1584) -> Result<()> {
1585    leave_room_for_session(&session, session.connection_id).await?;
1586    response.send(proto::Ack {})?;
1587    Ok(())
1588}
1589
1590/// Updates the permissions of someone else in the room.
1591async fn set_room_participant_role(
1592    request: proto::SetRoomParticipantRole,
1593    response: Response<proto::SetRoomParticipantRole>,
1594    session: MessageContext,
1595) -> Result<()> {
1596    let user_id = UserId::from_proto(request.user_id);
1597    let role = ChannelRole::from(request.role());
1598
1599    let (livekit_room, can_publish) = {
1600        let room = session
1601            .db()
1602            .await
1603            .set_room_participant_role(
1604                session.user_id(),
1605                RoomId::from_proto(request.room_id),
1606                user_id,
1607                role,
1608            )
1609            .await?;
1610
1611        let livekit_room = room.livekit_room.clone();
1612        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1613        room_updated(&room, &session.peer);
1614        (livekit_room, can_publish)
1615    };
1616
1617    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1618        live_kit
1619            .update_participant(
1620                livekit_room.clone(),
1621                request.user_id.to_string(),
1622                livekit_api::proto::ParticipantPermission {
1623                    can_subscribe: true,
1624                    can_publish,
1625                    can_publish_data: can_publish,
1626                    hidden: false,
1627                    recorder: false,
1628                },
1629            )
1630            .await
1631            .trace_err();
1632    }
1633
1634    response.send(proto::Ack {})?;
1635    Ok(())
1636}
1637
1638/// Call someone else into the current room
1639async fn call(
1640    request: proto::Call,
1641    response: Response<proto::Call>,
1642    session: MessageContext,
1643) -> Result<()> {
1644    let room_id = RoomId::from_proto(request.room_id);
1645    let calling_user_id = session.user_id();
1646    let calling_connection_id = session.connection_id;
1647    let called_user_id = UserId::from_proto(request.called_user_id);
1648    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1649    if !session
1650        .db()
1651        .await
1652        .has_contact(calling_user_id, called_user_id)
1653        .await?
1654    {
1655        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1656    }
1657
1658    let incoming_call = {
1659        let (room, incoming_call) = &mut *session
1660            .db()
1661            .await
1662            .call(
1663                room_id,
1664                calling_user_id,
1665                calling_connection_id,
1666                called_user_id,
1667                initial_project_id,
1668            )
1669            .await?;
1670        room_updated(room, &session.peer);
1671        mem::take(incoming_call)
1672    };
1673    update_user_contacts(called_user_id, &session).await?;
1674
1675    let mut calls = session
1676        .connection_pool()
1677        .await
1678        .user_connection_ids(called_user_id)
1679        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1680        .collect::<FuturesUnordered<_>>();
1681
1682    while let Some(call_response) = calls.next().await {
1683        match call_response.as_ref() {
1684            Ok(_) => {
1685                response.send(proto::Ack {})?;
1686                return Ok(());
1687            }
1688            Err(_) => {
1689                call_response.trace_err();
1690            }
1691        }
1692    }
1693
1694    {
1695        let room = session
1696            .db()
1697            .await
1698            .call_failed(room_id, called_user_id)
1699            .await?;
1700        room_updated(&room, &session.peer);
1701    }
1702    update_user_contacts(called_user_id, &session).await?;
1703
1704    Err(anyhow!("failed to ring user"))?
1705}
1706
1707/// Cancel an outgoing call.
1708async fn cancel_call(
1709    request: proto::CancelCall,
1710    response: Response<proto::CancelCall>,
1711    session: MessageContext,
1712) -> Result<()> {
1713    let called_user_id = UserId::from_proto(request.called_user_id);
1714    let room_id = RoomId::from_proto(request.room_id);
1715    {
1716        let room = session
1717            .db()
1718            .await
1719            .cancel_call(room_id, session.connection_id, called_user_id)
1720            .await?;
1721        room_updated(&room, &session.peer);
1722    }
1723
1724    for connection_id in session
1725        .connection_pool()
1726        .await
1727        .user_connection_ids(called_user_id)
1728    {
1729        session
1730            .peer
1731            .send(
1732                connection_id,
1733                proto::CallCanceled {
1734                    room_id: room_id.to_proto(),
1735                },
1736            )
1737            .trace_err();
1738    }
1739    response.send(proto::Ack {})?;
1740
1741    update_user_contacts(called_user_id, &session).await?;
1742    Ok(())
1743}
1744
1745/// Decline an incoming call.
1746async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1747    let room_id = RoomId::from_proto(message.room_id);
1748    {
1749        let room = session
1750            .db()
1751            .await
1752            .decline_call(Some(room_id), session.user_id())
1753            .await?
1754            .context("declining call")?;
1755        room_updated(&room, &session.peer);
1756    }
1757
1758    for connection_id in session
1759        .connection_pool()
1760        .await
1761        .user_connection_ids(session.user_id())
1762    {
1763        session
1764            .peer
1765            .send(
1766                connection_id,
1767                proto::CallCanceled {
1768                    room_id: room_id.to_proto(),
1769                },
1770            )
1771            .trace_err();
1772    }
1773    update_user_contacts(session.user_id(), &session).await?;
1774    Ok(())
1775}
1776
1777/// Updates other participants in the room with your current location.
1778async fn update_participant_location(
1779    request: proto::UpdateParticipantLocation,
1780    response: Response<proto::UpdateParticipantLocation>,
1781    session: MessageContext,
1782) -> Result<()> {
1783    let room_id = RoomId::from_proto(request.room_id);
1784    let location = request.location.context("invalid location")?;
1785
1786    let db = session.db().await;
1787    let room = db
1788        .update_room_participant_location(room_id, session.connection_id, location)
1789        .await?;
1790
1791    room_updated(&room, &session.peer);
1792    response.send(proto::Ack {})?;
1793    Ok(())
1794}
1795
1796/// Share a project into the room.
1797async fn share_project(
1798    request: proto::ShareProject,
1799    response: Response<proto::ShareProject>,
1800    session: MessageContext,
1801) -> Result<()> {
1802    let (project_id, room) = &*session
1803        .db()
1804        .await
1805        .share_project(
1806            RoomId::from_proto(request.room_id),
1807            session.connection_id,
1808            &request.worktrees,
1809            request.is_ssh_project,
1810            request.windows_paths.unwrap_or(false),
1811        )
1812        .await?;
1813    response.send(proto::ShareProjectResponse {
1814        project_id: project_id.to_proto(),
1815    })?;
1816    room_updated(room, &session.peer);
1817
1818    Ok(())
1819}
1820
1821/// Unshare a project from the room.
1822async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1823    let project_id = ProjectId::from_proto(message.project_id);
1824    unshare_project_internal(project_id, session.connection_id, &session).await
1825}
1826
1827async fn unshare_project_internal(
1828    project_id: ProjectId,
1829    connection_id: ConnectionId,
1830    session: &Session,
1831) -> Result<()> {
1832    let delete = {
1833        let room_guard = session
1834            .db()
1835            .await
1836            .unshare_project(project_id, connection_id)
1837            .await?;
1838
1839        let (delete, room, guest_connection_ids) = &*room_guard;
1840
1841        let message = proto::UnshareProject {
1842            project_id: project_id.to_proto(),
1843        };
1844
1845        broadcast(
1846            Some(connection_id),
1847            guest_connection_ids.iter().copied(),
1848            |conn_id| session.peer.send(conn_id, message.clone()),
1849        );
1850        if let Some(room) = room {
1851            room_updated(room, &session.peer);
1852        }
1853
1854        *delete
1855    };
1856
1857    if delete {
1858        let db = session.db().await;
1859        db.delete_project(project_id).await?;
1860    }
1861
1862    Ok(())
1863}
1864
1865/// Join someone elses shared project.
1866async fn join_project(
1867    request: proto::JoinProject,
1868    response: Response<proto::JoinProject>,
1869    session: MessageContext,
1870) -> Result<()> {
1871    let project_id = ProjectId::from_proto(request.project_id);
1872
1873    tracing::info!(%project_id, "join project");
1874
1875    let db = session.db().await;
1876    let (project, replica_id) = &mut *db
1877        .join_project(
1878            project_id,
1879            session.connection_id,
1880            session.user_id(),
1881            request.committer_name.clone(),
1882            request.committer_email.clone(),
1883        )
1884        .await?;
1885    drop(db);
1886    tracing::info!(%project_id, "join remote project");
1887    let collaborators = project
1888        .collaborators
1889        .iter()
1890        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1891        .map(|collaborator| collaborator.to_proto())
1892        .collect::<Vec<_>>();
1893    let project_id = project.id;
1894    let guest_user_id = session.user_id();
1895
1896    let worktrees = project
1897        .worktrees
1898        .iter()
1899        .map(|(id, worktree)| proto::WorktreeMetadata {
1900            id: *id,
1901            root_name: worktree.root_name.clone(),
1902            visible: worktree.visible,
1903            abs_path: worktree.abs_path.clone(),
1904        })
1905        .collect::<Vec<_>>();
1906
1907    let add_project_collaborator = proto::AddProjectCollaborator {
1908        project_id: project_id.to_proto(),
1909        collaborator: Some(proto::Collaborator {
1910            peer_id: Some(session.connection_id.into()),
1911            replica_id: replica_id.0 as u32,
1912            user_id: guest_user_id.to_proto(),
1913            is_host: false,
1914            committer_name: request.committer_name.clone(),
1915            committer_email: request.committer_email.clone(),
1916        }),
1917    };
1918
1919    for collaborator in &collaborators {
1920        session
1921            .peer
1922            .send(
1923                collaborator.peer_id.unwrap().into(),
1924                add_project_collaborator.clone(),
1925            )
1926            .trace_err();
1927    }
1928
1929    // First, we send the metadata associated with each worktree.
1930    let (language_servers, language_server_capabilities) = project
1931        .language_servers
1932        .clone()
1933        .into_iter()
1934        .map(|server| (server.server, server.capabilities))
1935        .unzip();
1936    response.send(proto::JoinProjectResponse {
1937        project_id: project.id.0 as u64,
1938        worktrees,
1939        replica_id: replica_id.0 as u32,
1940        collaborators,
1941        language_servers,
1942        language_server_capabilities,
1943        role: project.role.into(),
1944        windows_paths: project.path_style == PathStyle::Windows,
1945    })?;
1946
1947    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1948        // Stream this worktree's entries.
1949        let message = proto::UpdateWorktree {
1950            project_id: project_id.to_proto(),
1951            worktree_id,
1952            abs_path: worktree.abs_path.clone(),
1953            root_name: worktree.root_name,
1954            updated_entries: worktree.entries,
1955            removed_entries: Default::default(),
1956            scan_id: worktree.scan_id,
1957            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1958            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1959            removed_repositories: Default::default(),
1960        };
1961        for update in proto::split_worktree_update(message) {
1962            session.peer.send(session.connection_id, update.clone())?;
1963        }
1964
1965        // Stream this worktree's diagnostics.
1966        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1967        if let Some(summary) = worktree_diagnostics.next() {
1968            let message = proto::UpdateDiagnosticSummary {
1969                project_id: project.id.to_proto(),
1970                worktree_id: worktree.id,
1971                summary: Some(summary),
1972                more_summaries: worktree_diagnostics.collect(),
1973            };
1974            session.peer.send(session.connection_id, message)?;
1975        }
1976
1977        for settings_file in worktree.settings_files {
1978            session.peer.send(
1979                session.connection_id,
1980                proto::UpdateWorktreeSettings {
1981                    project_id: project_id.to_proto(),
1982                    worktree_id: worktree.id,
1983                    path: settings_file.path,
1984                    content: Some(settings_file.content),
1985                    kind: Some(settings_file.kind.to_proto() as i32),
1986                },
1987            )?;
1988        }
1989    }
1990
1991    for repository in mem::take(&mut project.repositories) {
1992        for update in split_repository_update(repository) {
1993            session.peer.send(session.connection_id, update)?;
1994        }
1995    }
1996
1997    for language_server in &project.language_servers {
1998        session.peer.send(
1999            session.connection_id,
2000            proto::UpdateLanguageServer {
2001                project_id: project_id.to_proto(),
2002                server_name: Some(language_server.server.name.clone()),
2003                language_server_id: language_server.server.id,
2004                variant: Some(
2005                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2006                        proto::LspDiskBasedDiagnosticsUpdated {},
2007                    ),
2008                ),
2009            },
2010        )?;
2011    }
2012
2013    Ok(())
2014}
2015
2016/// Leave someone elses shared project.
2017async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2018    let sender_id = session.connection_id;
2019    let project_id = ProjectId::from_proto(request.project_id);
2020    let db = session.db().await;
2021
2022    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2023    tracing::info!(
2024        %project_id,
2025        "leave project"
2026    );
2027
2028    project_left(project, &session);
2029    if let Some(room) = room {
2030        room_updated(room, &session.peer);
2031    }
2032
2033    Ok(())
2034}
2035
2036/// Updates other participants with changes to the project
2037async fn update_project(
2038    request: proto::UpdateProject,
2039    response: Response<proto::UpdateProject>,
2040    session: MessageContext,
2041) -> Result<()> {
2042    let project_id = ProjectId::from_proto(request.project_id);
2043    let (room, guest_connection_ids) = &*session
2044        .db()
2045        .await
2046        .update_project(project_id, session.connection_id, &request.worktrees)
2047        .await?;
2048    broadcast(
2049        Some(session.connection_id),
2050        guest_connection_ids.iter().copied(),
2051        |connection_id| {
2052            session
2053                .peer
2054                .forward_send(session.connection_id, connection_id, request.clone())
2055        },
2056    );
2057    if let Some(room) = room {
2058        room_updated(room, &session.peer);
2059    }
2060    response.send(proto::Ack {})?;
2061
2062    Ok(())
2063}
2064
2065/// Updates other participants with changes to the worktree
2066async fn update_worktree(
2067    request: proto::UpdateWorktree,
2068    response: Response<proto::UpdateWorktree>,
2069    session: MessageContext,
2070) -> Result<()> {
2071    let guest_connection_ids = session
2072        .db()
2073        .await
2074        .update_worktree(&request, session.connection_id)
2075        .await?;
2076
2077    broadcast(
2078        Some(session.connection_id),
2079        guest_connection_ids.iter().copied(),
2080        |connection_id| {
2081            session
2082                .peer
2083                .forward_send(session.connection_id, connection_id, request.clone())
2084        },
2085    );
2086    response.send(proto::Ack {})?;
2087    Ok(())
2088}
2089
2090async fn update_repository(
2091    request: proto::UpdateRepository,
2092    response: Response<proto::UpdateRepository>,
2093    session: MessageContext,
2094) -> Result<()> {
2095    let guest_connection_ids = session
2096        .db()
2097        .await
2098        .update_repository(&request, session.connection_id)
2099        .await?;
2100
2101    broadcast(
2102        Some(session.connection_id),
2103        guest_connection_ids.iter().copied(),
2104        |connection_id| {
2105            session
2106                .peer
2107                .forward_send(session.connection_id, connection_id, request.clone())
2108        },
2109    );
2110    response.send(proto::Ack {})?;
2111    Ok(())
2112}
2113
2114async fn remove_repository(
2115    request: proto::RemoveRepository,
2116    response: Response<proto::RemoveRepository>,
2117    session: MessageContext,
2118) -> Result<()> {
2119    let guest_connection_ids = session
2120        .db()
2121        .await
2122        .remove_repository(&request, session.connection_id)
2123        .await?;
2124
2125    broadcast(
2126        Some(session.connection_id),
2127        guest_connection_ids.iter().copied(),
2128        |connection_id| {
2129            session
2130                .peer
2131                .forward_send(session.connection_id, connection_id, request.clone())
2132        },
2133    );
2134    response.send(proto::Ack {})?;
2135    Ok(())
2136}
2137
2138/// Updates other participants with changes to the diagnostics
2139async fn update_diagnostic_summary(
2140    message: proto::UpdateDiagnosticSummary,
2141    session: MessageContext,
2142) -> Result<()> {
2143    let guest_connection_ids = session
2144        .db()
2145        .await
2146        .update_diagnostic_summary(&message, session.connection_id)
2147        .await?;
2148
2149    broadcast(
2150        Some(session.connection_id),
2151        guest_connection_ids.iter().copied(),
2152        |connection_id| {
2153            session
2154                .peer
2155                .forward_send(session.connection_id, connection_id, message.clone())
2156        },
2157    );
2158
2159    Ok(())
2160}
2161
2162/// Updates other participants with changes to the worktree settings
2163async fn update_worktree_settings(
2164    message: proto::UpdateWorktreeSettings,
2165    session: MessageContext,
2166) -> Result<()> {
2167    let guest_connection_ids = session
2168        .db()
2169        .await
2170        .update_worktree_settings(&message, session.connection_id)
2171        .await?;
2172
2173    broadcast(
2174        Some(session.connection_id),
2175        guest_connection_ids.iter().copied(),
2176        |connection_id| {
2177            session
2178                .peer
2179                .forward_send(session.connection_id, connection_id, message.clone())
2180        },
2181    );
2182
2183    Ok(())
2184}
2185
2186/// Notify other participants that a language server has started.
2187async fn start_language_server(
2188    request: proto::StartLanguageServer,
2189    session: MessageContext,
2190) -> Result<()> {
2191    let guest_connection_ids = session
2192        .db()
2193        .await
2194        .start_language_server(&request, session.connection_id)
2195        .await?;
2196
2197    broadcast(
2198        Some(session.connection_id),
2199        guest_connection_ids.iter().copied(),
2200        |connection_id| {
2201            session
2202                .peer
2203                .forward_send(session.connection_id, connection_id, request.clone())
2204        },
2205    );
2206    Ok(())
2207}
2208
2209/// Notify other participants that a language server has changed.
2210async fn update_language_server(
2211    request: proto::UpdateLanguageServer,
2212    session: MessageContext,
2213) -> Result<()> {
2214    let project_id = ProjectId::from_proto(request.project_id);
2215    let db = session.db().await;
2216
2217    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2218        && let Some(capabilities) = update.capabilities.clone()
2219    {
2220        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2221            .await?;
2222    }
2223
2224    let project_connection_ids = db
2225        .project_connection_ids(project_id, session.connection_id, true)
2226        .await?;
2227    broadcast(
2228        Some(session.connection_id),
2229        project_connection_ids.iter().copied(),
2230        |connection_id| {
2231            session
2232                .peer
2233                .forward_send(session.connection_id, connection_id, request.clone())
2234        },
2235    );
2236    Ok(())
2237}
2238
2239/// forward a project request to the host. These requests should be read only
2240/// as guests are allowed to send them.
2241async fn forward_read_only_project_request<T>(
2242    request: T,
2243    response: Response<T>,
2244    session: MessageContext,
2245) -> Result<()>
2246where
2247    T: EntityMessage + RequestMessage,
2248{
2249    let project_id = ProjectId::from_proto(request.remote_entity_id());
2250    let host_connection_id = session
2251        .db()
2252        .await
2253        .host_for_read_only_project_request(project_id, session.connection_id)
2254        .await?;
2255    let payload = session.forward_request(host_connection_id, request).await?;
2256    response.send(payload)?;
2257    Ok(())
2258}
2259
2260/// forward a project request to the host. These requests are disallowed
2261/// for guests.
2262async fn forward_mutating_project_request<T>(
2263    request: T,
2264    response: Response<T>,
2265    session: MessageContext,
2266) -> Result<()>
2267where
2268    T: EntityMessage + RequestMessage,
2269{
2270    let project_id = ProjectId::from_proto(request.remote_entity_id());
2271
2272    let host_connection_id = session
2273        .db()
2274        .await
2275        .host_for_mutating_project_request(project_id, session.connection_id)
2276        .await?;
2277    let payload = session.forward_request(host_connection_id, request).await?;
2278    response.send(payload)?;
2279    Ok(())
2280}
2281
2282async fn lsp_query(
2283    request: proto::LspQuery,
2284    response: Response<proto::LspQuery>,
2285    session: MessageContext,
2286) -> Result<()> {
2287    let (name, should_write) = request.query_name_and_write_permissions();
2288    tracing::Span::current().record("lsp_query_request", name);
2289    tracing::info!("lsp_query message received");
2290    if should_write {
2291        forward_mutating_project_request(request, response, session).await
2292    } else {
2293        forward_read_only_project_request(request, response, session).await
2294    }
2295}
2296
2297/// Notify other participants that a new buffer has been created
2298async fn create_buffer_for_peer(
2299    request: proto::CreateBufferForPeer,
2300    session: MessageContext,
2301) -> Result<()> {
2302    session
2303        .db()
2304        .await
2305        .check_user_is_project_host(
2306            ProjectId::from_proto(request.project_id),
2307            session.connection_id,
2308        )
2309        .await?;
2310    let peer_id = request.peer_id.context("invalid peer id")?;
2311    session
2312        .peer
2313        .forward_send(session.connection_id, peer_id.into(), request)?;
2314    Ok(())
2315}
2316
2317/// Notify other participants that a new image has been created
2318async fn create_image_for_peer(
2319    request: proto::CreateImageForPeer,
2320    session: MessageContext,
2321) -> Result<()> {
2322    session
2323        .db()
2324        .await
2325        .check_user_is_project_host(
2326            ProjectId::from_proto(request.project_id),
2327            session.connection_id,
2328        )
2329        .await?;
2330    let peer_id = request.peer_id.context("invalid peer id")?;
2331    session
2332        .peer
2333        .forward_send(session.connection_id, peer_id.into(), request)?;
2334    Ok(())
2335}
2336
2337/// Notify other participants that a buffer has been updated. This is
2338/// allowed for guests as long as the update is limited to selections.
2339async fn update_buffer(
2340    request: proto::UpdateBuffer,
2341    response: Response<proto::UpdateBuffer>,
2342    session: MessageContext,
2343) -> Result<()> {
2344    let project_id = ProjectId::from_proto(request.project_id);
2345    let mut capability = Capability::ReadOnly;
2346
2347    for op in request.operations.iter() {
2348        match op.variant {
2349            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2350            Some(_) => capability = Capability::ReadWrite,
2351        }
2352    }
2353
2354    let host = {
2355        let guard = session
2356            .db()
2357            .await
2358            .connections_for_buffer_update(project_id, session.connection_id, capability)
2359            .await?;
2360
2361        let (host, guests) = &*guard;
2362
2363        broadcast(
2364            Some(session.connection_id),
2365            guests.clone(),
2366            |connection_id| {
2367                session
2368                    .peer
2369                    .forward_send(session.connection_id, connection_id, request.clone())
2370            },
2371        );
2372
2373        *host
2374    };
2375
2376    if host != session.connection_id {
2377        session.forward_request(host, request.clone()).await?;
2378    }
2379
2380    response.send(proto::Ack {})?;
2381    Ok(())
2382}
2383
2384async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2385    let project_id = ProjectId::from_proto(message.project_id);
2386
2387    let operation = message.operation.as_ref().context("invalid operation")?;
2388    let capability = match operation.variant.as_ref() {
2389        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2390            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2391                match buffer_op.variant {
2392                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2393                        Capability::ReadOnly
2394                    }
2395                    _ => Capability::ReadWrite,
2396                }
2397            } else {
2398                Capability::ReadWrite
2399            }
2400        }
2401        Some(_) => Capability::ReadWrite,
2402        None => Capability::ReadOnly,
2403    };
2404
2405    let guard = session
2406        .db()
2407        .await
2408        .connections_for_buffer_update(project_id, session.connection_id, capability)
2409        .await?;
2410
2411    let (host, guests) = &*guard;
2412
2413    broadcast(
2414        Some(session.connection_id),
2415        guests.iter().chain([host]).copied(),
2416        |connection_id| {
2417            session
2418                .peer
2419                .forward_send(session.connection_id, connection_id, message.clone())
2420        },
2421    );
2422
2423    Ok(())
2424}
2425
2426/// Notify other participants that a project has been updated.
2427async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2428    request: T,
2429    session: MessageContext,
2430) -> Result<()> {
2431    let project_id = ProjectId::from_proto(request.remote_entity_id());
2432    let project_connection_ids = session
2433        .db()
2434        .await
2435        .project_connection_ids(project_id, session.connection_id, false)
2436        .await?;
2437
2438    broadcast(
2439        Some(session.connection_id),
2440        project_connection_ids.iter().copied(),
2441        |connection_id| {
2442            session
2443                .peer
2444                .forward_send(session.connection_id, connection_id, request.clone())
2445        },
2446    );
2447    Ok(())
2448}
2449
2450/// Start following another user in a call.
2451async fn follow(
2452    request: proto::Follow,
2453    response: Response<proto::Follow>,
2454    session: MessageContext,
2455) -> Result<()> {
2456    let room_id = RoomId::from_proto(request.room_id);
2457    let project_id = request.project_id.map(ProjectId::from_proto);
2458    let leader_id = request.leader_id.context("invalid leader id")?.into();
2459    let follower_id = session.connection_id;
2460
2461    session
2462        .db()
2463        .await
2464        .check_room_participants(room_id, leader_id, session.connection_id)
2465        .await?;
2466
2467    let response_payload = session.forward_request(leader_id, request).await?;
2468    response.send(response_payload)?;
2469
2470    if let Some(project_id) = project_id {
2471        let room = session
2472            .db()
2473            .await
2474            .follow(room_id, project_id, leader_id, follower_id)
2475            .await?;
2476        room_updated(&room, &session.peer);
2477    }
2478
2479    Ok(())
2480}
2481
2482/// Stop following another user in a call.
2483async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2484    let room_id = RoomId::from_proto(request.room_id);
2485    let project_id = request.project_id.map(ProjectId::from_proto);
2486    let leader_id = request.leader_id.context("invalid leader id")?.into();
2487    let follower_id = session.connection_id;
2488
2489    session
2490        .db()
2491        .await
2492        .check_room_participants(room_id, leader_id, session.connection_id)
2493        .await?;
2494
2495    session
2496        .peer
2497        .forward_send(session.connection_id, leader_id, request)?;
2498
2499    if let Some(project_id) = project_id {
2500        let room = session
2501            .db()
2502            .await
2503            .unfollow(room_id, project_id, leader_id, follower_id)
2504            .await?;
2505        room_updated(&room, &session.peer);
2506    }
2507
2508    Ok(())
2509}
2510
2511/// Notify everyone following you of your current location.
2512async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2513    let room_id = RoomId::from_proto(request.room_id);
2514    let database = session.db.lock().await;
2515
2516    let connection_ids = if let Some(project_id) = request.project_id {
2517        let project_id = ProjectId::from_proto(project_id);
2518        database
2519            .project_connection_ids(project_id, session.connection_id, true)
2520            .await?
2521    } else {
2522        database
2523            .room_connection_ids(room_id, session.connection_id)
2524            .await?
2525    };
2526
2527    // For now, don't send view update messages back to that view's current leader.
2528    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2529        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2530        _ => None,
2531    });
2532
2533    for connection_id in connection_ids.iter().cloned() {
2534        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2535            session
2536                .peer
2537                .forward_send(session.connection_id, connection_id, request.clone())?;
2538        }
2539    }
2540    Ok(())
2541}
2542
2543/// Get public data about users.
2544async fn get_users(
2545    request: proto::GetUsers,
2546    response: Response<proto::GetUsers>,
2547    session: MessageContext,
2548) -> Result<()> {
2549    let user_ids = request
2550        .user_ids
2551        .into_iter()
2552        .map(UserId::from_proto)
2553        .collect();
2554    let users = session
2555        .db()
2556        .await
2557        .get_users_by_ids(user_ids)
2558        .await?
2559        .into_iter()
2560        .map(|user| proto::User {
2561            id: user.id.to_proto(),
2562            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2563            github_login: user.github_login,
2564            name: user.name,
2565        })
2566        .collect();
2567    response.send(proto::UsersResponse { users })?;
2568    Ok(())
2569}
2570
2571/// Search for users (to invite) buy Github login
2572async fn fuzzy_search_users(
2573    request: proto::FuzzySearchUsers,
2574    response: Response<proto::FuzzySearchUsers>,
2575    session: MessageContext,
2576) -> Result<()> {
2577    let query = request.query;
2578    let users = match query.len() {
2579        0 => vec![],
2580        1 | 2 => session
2581            .db()
2582            .await
2583            .get_user_by_github_login(&query)
2584            .await?
2585            .into_iter()
2586            .collect(),
2587        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2588    };
2589    let users = users
2590        .into_iter()
2591        .filter(|user| user.id != session.user_id())
2592        .map(|user| proto::User {
2593            id: user.id.to_proto(),
2594            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2595            github_login: user.github_login,
2596            name: user.name,
2597        })
2598        .collect();
2599    response.send(proto::UsersResponse { users })?;
2600    Ok(())
2601}
2602
2603/// Send a contact request to another user.
2604async fn request_contact(
2605    request: proto::RequestContact,
2606    response: Response<proto::RequestContact>,
2607    session: MessageContext,
2608) -> Result<()> {
2609    let requester_id = session.user_id();
2610    let responder_id = UserId::from_proto(request.responder_id);
2611    if requester_id == responder_id {
2612        return Err(anyhow!("cannot add yourself as a contact"))?;
2613    }
2614
2615    let notifications = session
2616        .db()
2617        .await
2618        .send_contact_request(requester_id, responder_id)
2619        .await?;
2620
2621    // Update outgoing contact requests of requester
2622    let mut update = proto::UpdateContacts::default();
2623    update.outgoing_requests.push(responder_id.to_proto());
2624    for connection_id in session
2625        .connection_pool()
2626        .await
2627        .user_connection_ids(requester_id)
2628    {
2629        session.peer.send(connection_id, update.clone())?;
2630    }
2631
2632    // Update incoming contact requests of responder
2633    let mut update = proto::UpdateContacts::default();
2634    update
2635        .incoming_requests
2636        .push(proto::IncomingContactRequest {
2637            requester_id: requester_id.to_proto(),
2638        });
2639    let connection_pool = session.connection_pool().await;
2640    for connection_id in connection_pool.user_connection_ids(responder_id) {
2641        session.peer.send(connection_id, update.clone())?;
2642    }
2643
2644    send_notifications(&connection_pool, &session.peer, notifications);
2645
2646    response.send(proto::Ack {})?;
2647    Ok(())
2648}
2649
2650/// Accept or decline a contact request
2651async fn respond_to_contact_request(
2652    request: proto::RespondToContactRequest,
2653    response: Response<proto::RespondToContactRequest>,
2654    session: MessageContext,
2655) -> Result<()> {
2656    let responder_id = session.user_id();
2657    let requester_id = UserId::from_proto(request.requester_id);
2658    let db = session.db().await;
2659    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2660        db.dismiss_contact_notification(responder_id, requester_id)
2661            .await?;
2662    } else {
2663        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2664
2665        let notifications = db
2666            .respond_to_contact_request(responder_id, requester_id, accept)
2667            .await?;
2668        let requester_busy = db.is_user_busy(requester_id).await?;
2669        let responder_busy = db.is_user_busy(responder_id).await?;
2670
2671        let pool = session.connection_pool().await;
2672        // Update responder with new contact
2673        let mut update = proto::UpdateContacts::default();
2674        if accept {
2675            update
2676                .contacts
2677                .push(contact_for_user(requester_id, requester_busy, &pool));
2678        }
2679        update
2680            .remove_incoming_requests
2681            .push(requester_id.to_proto());
2682        for connection_id in pool.user_connection_ids(responder_id) {
2683            session.peer.send(connection_id, update.clone())?;
2684        }
2685
2686        // Update requester with new contact
2687        let mut update = proto::UpdateContacts::default();
2688        if accept {
2689            update
2690                .contacts
2691                .push(contact_for_user(responder_id, responder_busy, &pool));
2692        }
2693        update
2694            .remove_outgoing_requests
2695            .push(responder_id.to_proto());
2696
2697        for connection_id in pool.user_connection_ids(requester_id) {
2698            session.peer.send(connection_id, update.clone())?;
2699        }
2700
2701        send_notifications(&pool, &session.peer, notifications);
2702    }
2703
2704    response.send(proto::Ack {})?;
2705    Ok(())
2706}
2707
2708/// Remove a contact.
2709async fn remove_contact(
2710    request: proto::RemoveContact,
2711    response: Response<proto::RemoveContact>,
2712    session: MessageContext,
2713) -> Result<()> {
2714    let requester_id = session.user_id();
2715    let responder_id = UserId::from_proto(request.user_id);
2716    let db = session.db().await;
2717    let (contact_accepted, deleted_notification_id) =
2718        db.remove_contact(requester_id, responder_id).await?;
2719
2720    let pool = session.connection_pool().await;
2721    // Update outgoing contact requests of requester
2722    let mut update = proto::UpdateContacts::default();
2723    if contact_accepted {
2724        update.remove_contacts.push(responder_id.to_proto());
2725    } else {
2726        update
2727            .remove_outgoing_requests
2728            .push(responder_id.to_proto());
2729    }
2730    for connection_id in pool.user_connection_ids(requester_id) {
2731        session.peer.send(connection_id, update.clone())?;
2732    }
2733
2734    // Update incoming contact requests of responder
2735    let mut update = proto::UpdateContacts::default();
2736    if contact_accepted {
2737        update.remove_contacts.push(requester_id.to_proto());
2738    } else {
2739        update
2740            .remove_incoming_requests
2741            .push(requester_id.to_proto());
2742    }
2743    for connection_id in pool.user_connection_ids(responder_id) {
2744        session.peer.send(connection_id, update.clone())?;
2745        if let Some(notification_id) = deleted_notification_id {
2746            session.peer.send(
2747                connection_id,
2748                proto::DeleteNotification {
2749                    notification_id: notification_id.to_proto(),
2750                },
2751            )?;
2752        }
2753    }
2754
2755    response.send(proto::Ack {})?;
2756    Ok(())
2757}
2758
2759fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
2760    version.0.minor < 139
2761}
2762
2763async fn subscribe_to_channels(
2764    _: proto::SubscribeToChannels,
2765    session: MessageContext,
2766) -> Result<()> {
2767    subscribe_user_to_channels(session.user_id(), &session).await?;
2768    Ok(())
2769}
2770
2771async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2772    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2773    let mut pool = session.connection_pool().await;
2774    for membership in &channels_for_user.channel_memberships {
2775        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2776    }
2777    session.peer.send(
2778        session.connection_id,
2779        build_update_user_channels(&channels_for_user),
2780    )?;
2781    session.peer.send(
2782        session.connection_id,
2783        build_channels_update(channels_for_user),
2784    )?;
2785    Ok(())
2786}
2787
2788/// Creates a new channel.
2789async fn create_channel(
2790    request: proto::CreateChannel,
2791    response: Response<proto::CreateChannel>,
2792    session: MessageContext,
2793) -> Result<()> {
2794    let db = session.db().await;
2795
2796    let parent_id = request.parent_id.map(ChannelId::from_proto);
2797    let (channel, membership) = db
2798        .create_channel(&request.name, parent_id, session.user_id())
2799        .await?;
2800
2801    let root_id = channel.root_id();
2802    let channel = Channel::from_model(channel);
2803
2804    response.send(proto::CreateChannelResponse {
2805        channel: Some(channel.to_proto()),
2806        parent_id: request.parent_id,
2807    })?;
2808
2809    let mut connection_pool = session.connection_pool().await;
2810    if let Some(membership) = membership {
2811        connection_pool.subscribe_to_channel(
2812            membership.user_id,
2813            membership.channel_id,
2814            membership.role,
2815        );
2816        let update = proto::UpdateUserChannels {
2817            channel_memberships: vec![proto::ChannelMembership {
2818                channel_id: membership.channel_id.to_proto(),
2819                role: membership.role.into(),
2820            }],
2821            ..Default::default()
2822        };
2823        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2824            session.peer.send(connection_id, update.clone())?;
2825        }
2826    }
2827
2828    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2829        if !role.can_see_channel(channel.visibility) {
2830            continue;
2831        }
2832
2833        let update = proto::UpdateChannels {
2834            channels: vec![channel.to_proto()],
2835            ..Default::default()
2836        };
2837        session.peer.send(connection_id, update.clone())?;
2838    }
2839
2840    Ok(())
2841}
2842
2843/// Delete a channel
2844async fn delete_channel(
2845    request: proto::DeleteChannel,
2846    response: Response<proto::DeleteChannel>,
2847    session: MessageContext,
2848) -> Result<()> {
2849    let db = session.db().await;
2850
2851    let channel_id = request.channel_id;
2852    let (root_channel, removed_channels) = db
2853        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2854        .await?;
2855    response.send(proto::Ack {})?;
2856
2857    // Notify members of removed channels
2858    let mut update = proto::UpdateChannels::default();
2859    update
2860        .delete_channels
2861        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2862
2863    let connection_pool = session.connection_pool().await;
2864    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2865        session.peer.send(connection_id, update.clone())?;
2866    }
2867
2868    Ok(())
2869}
2870
2871/// Invite someone to join a channel.
2872async fn invite_channel_member(
2873    request: proto::InviteChannelMember,
2874    response: Response<proto::InviteChannelMember>,
2875    session: MessageContext,
2876) -> Result<()> {
2877    let db = session.db().await;
2878    let channel_id = ChannelId::from_proto(request.channel_id);
2879    let invitee_id = UserId::from_proto(request.user_id);
2880    let InviteMemberResult {
2881        channel,
2882        notifications,
2883    } = db
2884        .invite_channel_member(
2885            channel_id,
2886            invitee_id,
2887            session.user_id(),
2888            request.role().into(),
2889        )
2890        .await?;
2891
2892    let update = proto::UpdateChannels {
2893        channel_invitations: vec![channel.to_proto()],
2894        ..Default::default()
2895    };
2896
2897    let connection_pool = session.connection_pool().await;
2898    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2899        session.peer.send(connection_id, update.clone())?;
2900    }
2901
2902    send_notifications(&connection_pool, &session.peer, notifications);
2903
2904    response.send(proto::Ack {})?;
2905    Ok(())
2906}
2907
2908/// remove someone from a channel
2909async fn remove_channel_member(
2910    request: proto::RemoveChannelMember,
2911    response: Response<proto::RemoveChannelMember>,
2912    session: MessageContext,
2913) -> Result<()> {
2914    let db = session.db().await;
2915    let channel_id = ChannelId::from_proto(request.channel_id);
2916    let member_id = UserId::from_proto(request.user_id);
2917
2918    let RemoveChannelMemberResult {
2919        membership_update,
2920        notification_id,
2921    } = db
2922        .remove_channel_member(channel_id, member_id, session.user_id())
2923        .await?;
2924
2925    let mut connection_pool = session.connection_pool().await;
2926    notify_membership_updated(
2927        &mut connection_pool,
2928        membership_update,
2929        member_id,
2930        &session.peer,
2931    );
2932    for connection_id in connection_pool.user_connection_ids(member_id) {
2933        if let Some(notification_id) = notification_id {
2934            session
2935                .peer
2936                .send(
2937                    connection_id,
2938                    proto::DeleteNotification {
2939                        notification_id: notification_id.to_proto(),
2940                    },
2941                )
2942                .trace_err();
2943        }
2944    }
2945
2946    response.send(proto::Ack {})?;
2947    Ok(())
2948}
2949
2950/// Toggle the channel between public and private.
2951/// Care is taken to maintain the invariant that public channels only descend from public channels,
2952/// (though members-only channels can appear at any point in the hierarchy).
2953async fn set_channel_visibility(
2954    request: proto::SetChannelVisibility,
2955    response: Response<proto::SetChannelVisibility>,
2956    session: MessageContext,
2957) -> Result<()> {
2958    let db = session.db().await;
2959    let channel_id = ChannelId::from_proto(request.channel_id);
2960    let visibility = request.visibility().into();
2961
2962    let channel_model = db
2963        .set_channel_visibility(channel_id, visibility, session.user_id())
2964        .await?;
2965    let root_id = channel_model.root_id();
2966    let channel = Channel::from_model(channel_model);
2967
2968    let mut connection_pool = session.connection_pool().await;
2969    for (user_id, role) in connection_pool
2970        .channel_user_ids(root_id)
2971        .collect::<Vec<_>>()
2972        .into_iter()
2973    {
2974        let update = if role.can_see_channel(channel.visibility) {
2975            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2976            proto::UpdateChannels {
2977                channels: vec![channel.to_proto()],
2978                ..Default::default()
2979            }
2980        } else {
2981            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2982            proto::UpdateChannels {
2983                delete_channels: vec![channel.id.to_proto()],
2984                ..Default::default()
2985            }
2986        };
2987
2988        for connection_id in connection_pool.user_connection_ids(user_id) {
2989            session.peer.send(connection_id, update.clone())?;
2990        }
2991    }
2992
2993    response.send(proto::Ack {})?;
2994    Ok(())
2995}
2996
2997/// Alter the role for a user in the channel.
2998async fn set_channel_member_role(
2999    request: proto::SetChannelMemberRole,
3000    response: Response<proto::SetChannelMemberRole>,
3001    session: MessageContext,
3002) -> Result<()> {
3003    let db = session.db().await;
3004    let channel_id = ChannelId::from_proto(request.channel_id);
3005    let member_id = UserId::from_proto(request.user_id);
3006    let result = db
3007        .set_channel_member_role(
3008            channel_id,
3009            session.user_id(),
3010            member_id,
3011            request.role().into(),
3012        )
3013        .await?;
3014
3015    match result {
3016        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3017            let mut connection_pool = session.connection_pool().await;
3018            notify_membership_updated(
3019                &mut connection_pool,
3020                membership_update,
3021                member_id,
3022                &session.peer,
3023            )
3024        }
3025        db::SetMemberRoleResult::InviteUpdated(channel) => {
3026            let update = proto::UpdateChannels {
3027                channel_invitations: vec![channel.to_proto()],
3028                ..Default::default()
3029            };
3030
3031            for connection_id in session
3032                .connection_pool()
3033                .await
3034                .user_connection_ids(member_id)
3035            {
3036                session.peer.send(connection_id, update.clone())?;
3037            }
3038        }
3039    }
3040
3041    response.send(proto::Ack {})?;
3042    Ok(())
3043}
3044
3045/// Change the name of a channel
3046async fn rename_channel(
3047    request: proto::RenameChannel,
3048    response: Response<proto::RenameChannel>,
3049    session: MessageContext,
3050) -> Result<()> {
3051    let db = session.db().await;
3052    let channel_id = ChannelId::from_proto(request.channel_id);
3053    let channel_model = db
3054        .rename_channel(channel_id, session.user_id(), &request.name)
3055        .await?;
3056    let root_id = channel_model.root_id();
3057    let channel = Channel::from_model(channel_model);
3058
3059    response.send(proto::RenameChannelResponse {
3060        channel: Some(channel.to_proto()),
3061    })?;
3062
3063    let connection_pool = session.connection_pool().await;
3064    let update = proto::UpdateChannels {
3065        channels: vec![channel.to_proto()],
3066        ..Default::default()
3067    };
3068    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3069        if role.can_see_channel(channel.visibility) {
3070            session.peer.send(connection_id, update.clone())?;
3071        }
3072    }
3073
3074    Ok(())
3075}
3076
3077/// Move a channel to a new parent.
3078async fn move_channel(
3079    request: proto::MoveChannel,
3080    response: Response<proto::MoveChannel>,
3081    session: MessageContext,
3082) -> Result<()> {
3083    let channel_id = ChannelId::from_proto(request.channel_id);
3084    let to = ChannelId::from_proto(request.to);
3085
3086    let (root_id, channels) = session
3087        .db()
3088        .await
3089        .move_channel(channel_id, to, session.user_id())
3090        .await?;
3091
3092    let connection_pool = session.connection_pool().await;
3093    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3094        let channels = channels
3095            .iter()
3096            .filter_map(|channel| {
3097                if role.can_see_channel(channel.visibility) {
3098                    Some(channel.to_proto())
3099                } else {
3100                    None
3101                }
3102            })
3103            .collect::<Vec<_>>();
3104        if channels.is_empty() {
3105            continue;
3106        }
3107
3108        let update = proto::UpdateChannels {
3109            channels,
3110            ..Default::default()
3111        };
3112
3113        session.peer.send(connection_id, update.clone())?;
3114    }
3115
3116    response.send(Ack {})?;
3117    Ok(())
3118}
3119
3120async fn reorder_channel(
3121    request: proto::ReorderChannel,
3122    response: Response<proto::ReorderChannel>,
3123    session: MessageContext,
3124) -> Result<()> {
3125    let channel_id = ChannelId::from_proto(request.channel_id);
3126    let direction = request.direction();
3127
3128    let updated_channels = session
3129        .db()
3130        .await
3131        .reorder_channel(channel_id, direction, session.user_id())
3132        .await?;
3133
3134    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3135        let connection_pool = session.connection_pool().await;
3136        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3137            let channels = updated_channels
3138                .iter()
3139                .filter_map(|channel| {
3140                    if role.can_see_channel(channel.visibility) {
3141                        Some(channel.to_proto())
3142                    } else {
3143                        None
3144                    }
3145                })
3146                .collect::<Vec<_>>();
3147
3148            if channels.is_empty() {
3149                continue;
3150            }
3151
3152            let update = proto::UpdateChannels {
3153                channels,
3154                ..Default::default()
3155            };
3156
3157            session.peer.send(connection_id, update.clone())?;
3158        }
3159    }
3160
3161    response.send(Ack {})?;
3162    Ok(())
3163}
3164
3165/// Get the list of channel members
3166async fn get_channel_members(
3167    request: proto::GetChannelMembers,
3168    response: Response<proto::GetChannelMembers>,
3169    session: MessageContext,
3170) -> Result<()> {
3171    let db = session.db().await;
3172    let channel_id = ChannelId::from_proto(request.channel_id);
3173    let limit = if request.limit == 0 {
3174        u16::MAX as u64
3175    } else {
3176        request.limit
3177    };
3178    let (members, users) = db
3179        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3180        .await?;
3181    response.send(proto::GetChannelMembersResponse { members, users })?;
3182    Ok(())
3183}
3184
3185/// Accept or decline a channel invitation.
3186async fn respond_to_channel_invite(
3187    request: proto::RespondToChannelInvite,
3188    response: Response<proto::RespondToChannelInvite>,
3189    session: MessageContext,
3190) -> Result<()> {
3191    let db = session.db().await;
3192    let channel_id = ChannelId::from_proto(request.channel_id);
3193    let RespondToChannelInvite {
3194        membership_update,
3195        notifications,
3196    } = db
3197        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3198        .await?;
3199
3200    let mut connection_pool = session.connection_pool().await;
3201    if let Some(membership_update) = membership_update {
3202        notify_membership_updated(
3203            &mut connection_pool,
3204            membership_update,
3205            session.user_id(),
3206            &session.peer,
3207        );
3208    } else {
3209        let update = proto::UpdateChannels {
3210            remove_channel_invitations: vec![channel_id.to_proto()],
3211            ..Default::default()
3212        };
3213
3214        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3215            session.peer.send(connection_id, update.clone())?;
3216        }
3217    };
3218
3219    send_notifications(&connection_pool, &session.peer, notifications);
3220
3221    response.send(proto::Ack {})?;
3222
3223    Ok(())
3224}
3225
3226/// Join the channels' room
3227async fn join_channel(
3228    request: proto::JoinChannel,
3229    response: Response<proto::JoinChannel>,
3230    session: MessageContext,
3231) -> Result<()> {
3232    let channel_id = ChannelId::from_proto(request.channel_id);
3233    join_channel_internal(channel_id, Box::new(response), session).await
3234}
3235
/// Response handle through which `join_channel_internal` sends its reply.
///
/// Implemented below for both `Response<proto::JoinChannel>` and `Response<proto::JoinRoom>`.
/// Both carry a `proto::JoinRoomResponse` payload.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Forwards to `Response::<proto::JoinChannel>::send`.
// NOTE(review): this path is expected to resolve to `Response`'s own (inherent) reply method,
// not back into this trait method. Rewriting it as `self.send(..)` would need the same check.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Same forwarding for `JoinRoom` response handles.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3249
/// Joins the room belonging to `channel_id` on behalf of the session's user and replies
/// through `response`.
///
/// Steps:
/// 1. Clean up a stale room connection for the user, if the database reports one.
/// 2. Join the channel's room in the database.
/// 3. Reply with the room, the channel id, and optional LiveKit connection info.
/// 4. Propagate any membership change and broadcast the updated room.
/// 5. Broadcast the channel update and refresh the user's contacts.
///
/// # Errors
/// Propagates database and send errors. Also fails if the database does not return the
/// joined channel ("channel not returned").
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    // This block scopes the database guard (and the connection-pool guard taken inside it);
    // both are dropped before `channel_updated` re-acquires the pool below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room, then re-acquire it.
            // NOTE(review): presumably `leave_room_for_session` takes the same guard itself;
            // confirm before reordering.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when minting the token fails
        // (`trace_err()?` logs the error and short-circuits the closure).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests receive a guest token and may not publish; every other role gets
                    // a regular room token with publishing allowed.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3343
3344/// Start editing the channel notes
3345async fn join_channel_buffer(
3346    request: proto::JoinChannelBuffer,
3347    response: Response<proto::JoinChannelBuffer>,
3348    session: MessageContext,
3349) -> Result<()> {
3350    let db = session.db().await;
3351    let channel_id = ChannelId::from_proto(request.channel_id);
3352
3353    let open_response = db
3354        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3355        .await?;
3356
3357    let collaborators = open_response.collaborators.clone();
3358    response.send(open_response)?;
3359
3360    let update = UpdateChannelBufferCollaborators {
3361        channel_id: channel_id.to_proto(),
3362        collaborators: collaborators.clone(),
3363    };
3364    channel_buffer_updated(
3365        session.connection_id,
3366        collaborators
3367            .iter()
3368            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3369        &update,
3370        &session.peer,
3371    );
3372
3373    Ok(())
3374}
3375
3376/// Edit the channel notes
3377async fn update_channel_buffer(
3378    request: proto::UpdateChannelBuffer,
3379    session: MessageContext,
3380) -> Result<()> {
3381    let db = session.db().await;
3382    let channel_id = ChannelId::from_proto(request.channel_id);
3383
3384    let (collaborators, epoch, version) = db
3385        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3386        .await?;
3387
3388    channel_buffer_updated(
3389        session.connection_id,
3390        collaborators.clone(),
3391        &proto::UpdateChannelBuffer {
3392            channel_id: channel_id.to_proto(),
3393            operations: request.operations,
3394        },
3395        &session.peer,
3396    );
3397
3398    let pool = &*session.connection_pool().await;
3399
3400    let non_collaborators =
3401        pool.channel_connection_ids(channel_id)
3402            .filter_map(|(connection_id, _)| {
3403                if collaborators.contains(&connection_id) {
3404                    None
3405                } else {
3406                    Some(connection_id)
3407                }
3408            });
3409
3410    broadcast(None, non_collaborators, |peer_id| {
3411        session.peer.send(
3412            peer_id,
3413            proto::UpdateChannels {
3414                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3415                    channel_id: channel_id.to_proto(),
3416                    epoch: epoch as u64,
3417                    version: version.clone(),
3418                }],
3419                ..Default::default()
3420            },
3421        )
3422    });
3423
3424    Ok(())
3425}
3426
3427/// Rejoin the channel notes after a connection blip
3428async fn rejoin_channel_buffers(
3429    request: proto::RejoinChannelBuffers,
3430    response: Response<proto::RejoinChannelBuffers>,
3431    session: MessageContext,
3432) -> Result<()> {
3433    let db = session.db().await;
3434    let buffers = db
3435        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3436        .await?;
3437
3438    for rejoined_buffer in &buffers {
3439        let collaborators_to_notify = rejoined_buffer
3440            .buffer
3441            .collaborators
3442            .iter()
3443            .filter_map(|c| Some(c.peer_id?.into()));
3444        channel_buffer_updated(
3445            session.connection_id,
3446            collaborators_to_notify,
3447            &proto::UpdateChannelBufferCollaborators {
3448                channel_id: rejoined_buffer.buffer.channel_id,
3449                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3450            },
3451            &session.peer,
3452        );
3453    }
3454
3455    response.send(proto::RejoinChannelBuffersResponse {
3456        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3457    })?;
3458
3459    Ok(())
3460}
3461
3462/// Stop editing the channel notes
3463async fn leave_channel_buffer(
3464    request: proto::LeaveChannelBuffer,
3465    response: Response<proto::LeaveChannelBuffer>,
3466    session: MessageContext,
3467) -> Result<()> {
3468    let db = session.db().await;
3469    let channel_id = ChannelId::from_proto(request.channel_id);
3470
3471    let left_buffer = db
3472        .leave_channel_buffer(channel_id, session.connection_id)
3473        .await?;
3474
3475    response.send(Ack {})?;
3476
3477    channel_buffer_updated(
3478        session.connection_id,
3479        left_buffer.connections,
3480        &proto::UpdateChannelBufferCollaborators {
3481            channel_id: channel_id.to_proto(),
3482            collaborators: left_buffer.collaborators,
3483        },
3484        &session.peer,
3485    );
3486
3487    Ok(())
3488}
3489
3490fn channel_buffer_updated<T: EnvelopedMessage>(
3491    sender_id: ConnectionId,
3492    collaborators: impl IntoIterator<Item = ConnectionId>,
3493    message: &T,
3494    peer: &Peer,
3495) {
3496    broadcast(Some(sender_id), collaborators, |peer_id| {
3497        peer.send(peer_id, message.clone())
3498    });
3499}
3500
3501fn send_notifications(
3502    connection_pool: &ConnectionPool,
3503    peer: &Peer,
3504    notifications: db::NotificationBatch,
3505) {
3506    for (user_id, notification) in notifications {
3507        for connection_id in connection_pool.user_connection_ids(user_id) {
3508            if let Err(error) = peer.send(
3509                connection_id,
3510                proto::AddNotification {
3511                    notification: Some(notification.clone()),
3512                },
3513            ) {
3514                tracing::error!(
3515                    "failed to send notification to {:?} {}",
3516                    connection_id,
3517                    error
3518                );
3519            }
3520        }
3521    }
3522}
3523
3524/// Send a message to the channel
3525async fn send_channel_message(
3526    _request: proto::SendChannelMessage,
3527    _response: Response<proto::SendChannelMessage>,
3528    _session: MessageContext,
3529) -> Result<()> {
3530    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3531}
3532
3533/// Delete a channel message
3534async fn remove_channel_message(
3535    _request: proto::RemoveChannelMessage,
3536    _response: Response<proto::RemoveChannelMessage>,
3537    _session: MessageContext,
3538) -> Result<()> {
3539    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3540}
3541
3542async fn update_channel_message(
3543    _request: proto::UpdateChannelMessage,
3544    _response: Response<proto::UpdateChannelMessage>,
3545    _session: MessageContext,
3546) -> Result<()> {
3547    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3548}
3549
3550/// Mark a channel message as read
3551async fn acknowledge_channel_message(
3552    _request: proto::AckChannelMessage,
3553    _session: MessageContext,
3554) -> Result<()> {
3555    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3556}
3557
3558/// Mark a buffer version as synced
3559async fn acknowledge_buffer_version(
3560    request: proto::AckBufferOperation,
3561    session: MessageContext,
3562) -> Result<()> {
3563    let buffer_id = BufferId::from_proto(request.buffer_id);
3564    session
3565        .db()
3566        .await
3567        .observe_buffer_version(
3568            buffer_id,
3569            session.user_id(),
3570            request.epoch as i32,
3571            &request.version,
3572        )
3573        .await?;
3574    Ok(())
3575}
3576
3577/// Start receiving chat updates for a channel
3578async fn join_channel_chat(
3579    _request: proto::JoinChannelChat,
3580    _response: Response<proto::JoinChannelChat>,
3581    _session: MessageContext,
3582) -> Result<()> {
3583    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3584}
3585
3586/// Stop receiving chat updates for a channel
3587async fn leave_channel_chat(
3588    _request: proto::LeaveChannelChat,
3589    _session: MessageContext,
3590) -> Result<()> {
3591    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3592}
3593
3594/// Retrieve the chat history for a channel
3595async fn get_channel_messages(
3596    _request: proto::GetChannelMessages,
3597    _response: Response<proto::GetChannelMessages>,
3598    _session: MessageContext,
3599) -> Result<()> {
3600    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3601}
3602
3603/// Retrieve specific chat messages
3604async fn get_channel_messages_by_id(
3605    _request: proto::GetChannelMessagesById,
3606    _response: Response<proto::GetChannelMessagesById>,
3607    _session: MessageContext,
3608) -> Result<()> {
3609    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3610}
3611
3612/// Retrieve the current users notifications
3613async fn get_notifications(
3614    request: proto::GetNotifications,
3615    response: Response<proto::GetNotifications>,
3616    session: MessageContext,
3617) -> Result<()> {
3618    let notifications = session
3619        .db()
3620        .await
3621        .get_notifications(
3622            session.user_id(),
3623            NOTIFICATION_COUNT_PER_PAGE,
3624            request.before_id.map(db::NotificationId::from_proto),
3625        )
3626        .await?;
3627    response.send(proto::GetNotificationsResponse {
3628        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3629        notifications,
3630    })?;
3631    Ok(())
3632}
3633
3634/// Mark notifications as read
3635async fn mark_notification_as_read(
3636    request: proto::MarkNotificationRead,
3637    response: Response<proto::MarkNotificationRead>,
3638    session: MessageContext,
3639) -> Result<()> {
3640    let database = &session.db().await;
3641    let notifications = database
3642        .mark_notification_as_read_by_id(
3643            session.user_id(),
3644            NotificationId::from_proto(request.notification_id),
3645        )
3646        .await?;
3647    send_notifications(
3648        &*session.connection_pool().await,
3649        &session.peer,
3650        notifications,
3651    );
3652    response.send(proto::Ack {})?;
3653    Ok(())
3654}
3655
3656fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3657    let message = match message {
3658        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3659        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3660        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3661        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3662        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3663            code: frame.code.into(),
3664            reason: frame.reason.as_str().to_owned().into(),
3665        })),
3666        // We should never receive a frame while reading the message, according
3667        // to the `tungstenite` maintainers:
3668        //
3669        // > It cannot occur when you read messages from the WebSocket, but it
3670        // > can be used when you want to send the raw frames (e.g. you want to
3671        // > send the frames to the WebSocket without composing the full message first).
3672        // >
3673        // > — https://github.com/snapview/tungstenite-rs/issues/268
3674        TungsteniteMessage::Frame(_) => {
3675            bail!("received an unexpected frame while reading the message")
3676        }
3677    };
3678
3679    Ok(message)
3680}
3681
3682fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3683    match message {
3684        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3685        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3686        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3687        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3688        AxumMessage::Close(frame) => {
3689            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3690                code: frame.code.into(),
3691                reason: frame.reason.as_ref().into(),
3692            }))
3693        }
3694    }
3695}
3696
3697fn notify_membership_updated(
3698    connection_pool: &mut ConnectionPool,
3699    result: MembershipUpdated,
3700    user_id: UserId,
3701    peer: &Peer,
3702) {
3703    for membership in &result.new_channels.channel_memberships {
3704        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3705    }
3706    for channel_id in &result.removed_channels {
3707        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3708    }
3709
3710    let user_channels_update = proto::UpdateUserChannels {
3711        channel_memberships: result
3712            .new_channels
3713            .channel_memberships
3714            .iter()
3715            .map(|cm| proto::ChannelMembership {
3716                channel_id: cm.channel_id.to_proto(),
3717                role: cm.role.into(),
3718            })
3719            .collect(),
3720        ..Default::default()
3721    };
3722
3723    let mut update = build_channels_update(result.new_channels);
3724    update.delete_channels = result
3725        .removed_channels
3726        .into_iter()
3727        .map(|id| id.to_proto())
3728        .collect();
3729    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3730
3731    for connection_id in connection_pool.user_connection_ids(user_id) {
3732        peer.send(connection_id, user_channels_update.clone())
3733            .trace_err();
3734        peer.send(connection_id, update.clone()).trace_err();
3735    }
3736}
3737
3738fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3739    proto::UpdateUserChannels {
3740        channel_memberships: channels
3741            .channel_memberships
3742            .iter()
3743            .map(|m| proto::ChannelMembership {
3744                channel_id: m.channel_id.to_proto(),
3745                role: m.role.into(),
3746            })
3747            .collect(),
3748        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3749    }
3750}
3751
3752fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3753    let mut update = proto::UpdateChannels::default();
3754
3755    for channel in channels.channels {
3756        update.channels.push(channel.to_proto());
3757    }
3758
3759    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3760
3761    for (channel_id, participants) in channels.channel_participants {
3762        update
3763            .channel_participants
3764            .push(proto::ChannelParticipants {
3765                channel_id: channel_id.to_proto(),
3766                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3767            });
3768    }
3769
3770    for channel in channels.invited_channels {
3771        update.channel_invitations.push(channel.to_proto());
3772    }
3773
3774    update
3775}
3776
3777fn build_initial_contacts_update(
3778    contacts: Vec<db::Contact>,
3779    pool: &ConnectionPool,
3780) -> proto::UpdateContacts {
3781    let mut update = proto::UpdateContacts::default();
3782
3783    for contact in contacts {
3784        match contact {
3785            db::Contact::Accepted { user_id, busy } => {
3786                update.contacts.push(contact_for_user(user_id, busy, pool));
3787            }
3788            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3789            db::Contact::Incoming { user_id } => {
3790                update
3791                    .incoming_requests
3792                    .push(proto::IncomingContactRequest {
3793                        requester_id: user_id.to_proto(),
3794                    })
3795            }
3796        }
3797    }
3798
3799    update
3800}
3801
3802fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3803    proto::Contact {
3804        user_id: user_id.to_proto(),
3805        online: pool.is_user_online(user_id),
3806        busy,
3807    }
3808}
3809
3810fn room_updated(room: &proto::Room, peer: &Peer) {
3811    broadcast(
3812        None,
3813        room.participants
3814            .iter()
3815            .filter_map(|participant| Some(participant.peer_id?.into())),
3816        |peer_id| {
3817            peer.send(
3818                peer_id,
3819                proto::RoomUpdated {
3820                    room: Some(room.clone()),
3821                },
3822            )
3823        },
3824    );
3825}
3826
3827fn channel_updated(
3828    channel: &db::channel::Model,
3829    room: &proto::Room,
3830    peer: &Peer,
3831    pool: &ConnectionPool,
3832) {
3833    let participants = room
3834        .participants
3835        .iter()
3836        .map(|p| p.user_id)
3837        .collect::<Vec<_>>();
3838
3839    broadcast(
3840        None,
3841        pool.channel_connection_ids(channel.root_id())
3842            .filter_map(|(channel_id, role)| {
3843                role.can_see_channel(channel.visibility)
3844                    .then_some(channel_id)
3845            }),
3846        |peer_id| {
3847            peer.send(
3848                peer_id,
3849                proto::UpdateChannels {
3850                    channel_participants: vec![proto::ChannelParticipants {
3851                        channel_id: channel.id.to_proto(),
3852                        participant_user_ids: participants.clone(),
3853                    }],
3854                    ..Default::default()
3855                },
3856            )
3857        },
3858    );
3859}
3860
3861async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3862    let db = session.db().await;
3863
3864    let contacts = db.get_contacts(user_id).await?;
3865    let busy = db.is_user_busy(user_id).await?;
3866
3867    let pool = session.connection_pool().await;
3868    let updated_contact = contact_for_user(user_id, busy, &pool);
3869    for contact in contacts {
3870        if let db::Contact::Accepted {
3871            user_id: contact_user_id,
3872            ..
3873        } = contact
3874        {
3875            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3876                session
3877                    .peer
3878                    .send(
3879                        contact_conn_id,
3880                        proto::UpdateContacts {
3881                            contacts: vec![updated_contact.clone()],
3882                            remove_contacts: Default::default(),
3883                            incoming_requests: Default::default(),
3884                            remove_incoming_requests: Default::default(),
3885                            outgoing_requests: Default::default(),
3886                            remove_outgoing_requests: Default::default(),
3887                        },
3888                    )
3889                    .trace_err();
3890            }
3891        }
3892    }
3893    Ok(())
3894}
3895
3896async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
3897    let mut contacts_to_update = HashSet::default();
3898
3899    let room_id;
3900    let canceled_calls_to_user_ids;
3901    let livekit_room;
3902    let delete_livekit_room;
3903    let room;
3904    let channel;
3905
3906    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
3907        contacts_to_update.insert(session.user_id());
3908
3909        for project in left_room.left_projects.values() {
3910            project_left(project, session);
3911        }
3912
3913        room_id = RoomId::from_proto(left_room.room.id);
3914        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3915        livekit_room = mem::take(&mut left_room.room.livekit_room);
3916        delete_livekit_room = left_room.deleted;
3917        room = mem::take(&mut left_room.room);
3918        channel = mem::take(&mut left_room.channel);
3919
3920        room_updated(&room, &session.peer);
3921    } else {
3922        return Ok(());
3923    }
3924
3925    if let Some(channel) = channel {
3926        channel_updated(
3927            &channel,
3928            &room,
3929            &session.peer,
3930            &*session.connection_pool().await,
3931        );
3932    }
3933
3934    {
3935        let pool = session.connection_pool().await;
3936        for canceled_user_id in canceled_calls_to_user_ids {
3937            for connection_id in pool.user_connection_ids(canceled_user_id) {
3938                session
3939                    .peer
3940                    .send(
3941                        connection_id,
3942                        proto::CallCanceled {
3943                            room_id: room_id.to_proto(),
3944                        },
3945                    )
3946                    .trace_err();
3947            }
3948            contacts_to_update.insert(canceled_user_id);
3949        }
3950    }
3951
3952    for contact_user_id in contacts_to_update {
3953        update_user_contacts(contact_user_id, session).await?;
3954    }
3955
3956    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
3957        live_kit
3958            .remove_participant(livekit_room.clone(), session.user_id().to_string())
3959            .await
3960            .trace_err();
3961
3962        if delete_livekit_room {
3963            live_kit.delete_room(livekit_room).await.trace_err();
3964        }
3965    }
3966
3967    Ok(())
3968}
3969
3970async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3971    let left_channel_buffers = session
3972        .db()
3973        .await
3974        .leave_channel_buffers(session.connection_id)
3975        .await?;
3976
3977    for left_buffer in left_channel_buffers {
3978        channel_buffer_updated(
3979            session.connection_id,
3980            left_buffer.connections,
3981            &proto::UpdateChannelBufferCollaborators {
3982                channel_id: left_buffer.channel_id.to_proto(),
3983                collaborators: left_buffer.collaborators,
3984            },
3985            &session.peer,
3986        );
3987    }
3988
3989    Ok(())
3990}
3991
3992fn project_left(project: &db::LeftProject, session: &Session) {
3993    for connection_id in &project.connection_ids {
3994        if project.should_unshare {
3995            session
3996                .peer
3997                .send(
3998                    *connection_id,
3999                    proto::UnshareProject {
4000                        project_id: project.id.to_proto(),
4001                    },
4002                )
4003                .trace_err();
4004        } else {
4005            session
4006                .peer
4007                .send(
4008                    *connection_id,
4009                    proto::RemoveProjectCollaborator {
4010                        project_id: project.id.to_proto(),
4011                        peer_id: Some(session.connection_id.into()),
4012                    },
4013                )
4014                .trace_err();
4015        }
4016    }
4017}
4018
4019pub trait ResultExt {
4020    type Ok;
4021
4022    fn trace_err(self) -> Option<Self::Ok>;
4023}
4024
4025impl<T, E> ResultExt for Result<T, E>
4026where
4027    E: std::fmt::Debug,
4028{
4029    type Ok = T;
4030
4031    #[track_caller]
4032    fn trace_err(self) -> Option<T> {
4033        match self {
4034            Ok(value) => Some(value),
4035            Err(error) => {
4036                tracing::error!("{:?}", error);
4037                None
4038            }
4039        }
4040    }
4041}