rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::{
   5    AppState, Error, Result, auth,
   6    db::{
   7        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
   8        InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
   9        RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
  10        UserId,
  11    },
  12    executor::Executor,
  13};
  14use anyhow::{Context as _, anyhow, bail};
  15use async_tungstenite::tungstenite::{
  16    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
  17};
  18use axum::headers::UserAgent;
  19use axum::{
  20    Extension, Router, TypedHeader,
  21    body::Body,
  22    extract::{
  23        ConnectInfo, WebSocketUpgrade,
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use futures::TryFutureExt as _;
  36use rpc::proto::split_repository_update;
  37use tracing::Span;
  38use util::paths::PathStyle;
  39
  40use futures::{
  41    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
  42    stream::FuturesUnordered,
  43};
  44use prometheus::{IntGauge, register_int_gauge};
  45use rpc::{
  46    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  47    proto::{
  48        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  49        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  50    },
  51};
  52use semver::Version;
  53use serde::{Serialize, Serializer};
  54use std::{
  55    any::TypeId,
  56    future::Future,
  57    marker::PhantomData,
  58    mem,
  59    net::SocketAddr,
  60    ops::{Deref, DerefMut},
  61    rc::Rc,
  62    sync::{
  63        Arc, OnceLock,
  64        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
  65    },
  66    time::{Duration, Instant},
  67};
  68use tokio::sync::{Semaphore, watch};
  69use tower::ServiceBuilder;
  70use tracing::{
  71    Instrument,
  72    field::{self},
  73    info_span, instrument,
  74};
  75
// NOTE(review): not referenced in this part of the file; presumably the grace period during
// which a disconnected client may reconnect — confirm at the use site.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean
// up old resources. `Server::start` sleeps for this long before cleaning up after stale servers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): not referenced in this part of the file; presumably the page size used when
// fetching notifications — confirm.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
// Upper bound on simultaneously held `ConnectionGuard`s; `ConnectionGuard::try_acquire` fails
// once this many are outstanding.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

// Number of currently held `ConnectionGuard`s: incremented by `try_acquire`, decremented when a
// guard is dropped (or when `try_acquire` rejects an attempt).
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Span field names for per-message timings. Values are recorded in milliseconds as `f64`
// (computed from microseconds / 1000.0).
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

// Type-erased handler stored in `Server::handlers` (built by `Server::add_handler`): it receives
// the envelope, the connection's `Session`, and the message span, and returns a boxed future.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
  93
  94pub struct ConnectionGuard;
  95
  96impl ConnectionGuard {
  97    pub fn try_acquire() -> Result<Self, ()> {
  98        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
  99        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
 100            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 101            tracing::error!(
 102                "too many concurrent connections: {}",
 103                current_connections + 1
 104            );
 105            return Err(());
 106        }
 107        Ok(ConnectionGuard)
 108    }
 109}
 110
 111impl Drop for ConnectionGuard {
 112    fn drop(&mut self) {
 113        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
 114    }
 115}
 116
 117struct Response<R> {
 118    peer: Arc<Peer>,
 119    receipt: Receipt<R>,
 120    responded: Arc<AtomicBool>,
 121}
 122
 123impl<R: RequestMessage> Response<R> {
 124    fn send(self, payload: R::Response) -> Result<()> {
 125        self.responded.store(true, SeqCst);
 126        self.peer.respond(self.receipt, payload)?;
 127        Ok(())
 128    }
 129}
 130
 131#[derive(Clone, Debug)]
 132pub enum Principal {
 133    User(User),
 134    Impersonated { user: User, admin: User },
 135}
 136
 137impl Principal {
 138    fn update_span(&self, span: &tracing::Span) {
 139        match &self {
 140            Principal::User(user) => {
 141                span.record("user_id", user.id.0);
 142                span.record("login", &user.github_login);
 143            }
 144            Principal::Impersonated { user, admin } => {
 145                span.record("user_id", user.id.0);
 146                span.record("login", &user.github_login);
 147                span.record("impersonator", &admin.github_login);
 148            }
 149        }
 150    }
 151}
 152
 153#[derive(Clone)]
 154struct MessageContext {
 155    session: Session,
 156    span: tracing::Span,
 157}
 158
 159impl Deref for MessageContext {
 160    type Target = Session;
 161
 162    fn deref(&self) -> &Self::Target {
 163        &self.session
 164    }
 165}
 166
 167impl MessageContext {
 168    pub fn forward_request<T: RequestMessage>(
 169        &self,
 170        receiver_id: ConnectionId,
 171        request: T,
 172    ) -> impl Future<Output = anyhow::Result<T::Response>> {
 173        let request_start_time = Instant::now();
 174        let span = self.span.clone();
 175        tracing::info!("start forwarding request");
 176        self.peer
 177            .forward_request(self.connection_id, receiver_id, request)
 178            .inspect(move |_| {
 179                span.record(
 180                    HOST_WAITING_MS,
 181                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
 182                );
 183            })
 184            .inspect_err(|_| tracing::error!("error forwarding request"))
 185            .inspect_ok(|_| tracing::info!("finished forwarding request"))
 186    }
 187}
 188
/// Per-connection state passed (by value) to every message handler; see `MessageHandler`.
///
/// The shared resources (database, peer, connection pool, app state) are held in `Arc`s,
/// so clones of a `Session` share them.
#[derive(Clone)]
struct Session {
    /// Who this connection acts as (a user, or an admin impersonating a user).
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared connection pool; acquire it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// The GeoIP country code for the user.
    // Not currently read; `allow(unused)` silences the dead-field lint.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    // NOTE(review): presumably the client's system id (cf. the `SystemIdHeader` import) —
    // confirm. Not currently read, hence `allow(unused)`.
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
 204
 205impl Session {
 206    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
 207        #[cfg(test)]
 208        tokio::task::yield_now().await;
 209        let guard = self.db.lock().await;
 210        #[cfg(test)]
 211        tokio::task::yield_now().await;
 212        guard
 213    }
 214
 215    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 216        #[cfg(test)]
 217        tokio::task::yield_now().await;
 218        let guard = self.connection_pool.lock();
 219        ConnectionPoolGuard {
 220            guard,
 221            _not_send: PhantomData,
 222        }
 223    }
 224
 225    #[expect(dead_code)]
 226    fn is_staff(&self) -> bool {
 227        match &self.principal {
 228            Principal::User(user) => user.admin,
 229            Principal::Impersonated { .. } => true,
 230        }
 231    }
 232
 233    fn user_id(&self) -> UserId {
 234        match &self.principal {
 235            Principal::User(user) => user.id,
 236            Principal::Impersonated { user, .. } => user.id,
 237        }
 238    }
 239}
 240
 241impl Debug for Session {
 242    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 243        let mut result = f.debug_struct("Session");
 244        match &self.principal {
 245            Principal::User(user) => {
 246                result.field("user", &user.github_login);
 247            }
 248            Principal::Impersonated { user, admin } => {
 249                result.field("user", &user.github_login);
 250                result.field("impersonator", &admin.github_login);
 251            }
 252        }
 253        result.field("connection_id", &self.connection_id).finish()
 254    }
 255}
 256
 257struct DbHandle(Arc<Database>);
 258
 259impl Deref for DbHandle {
 260    type Target = Database;
 261
 262    fn deref(&self) -> &Self::Target {
 263        self.0.as_ref()
 264    }
 265}
 266
/// The collaboration RPC server: owns the peer, the connection pool, and the table of
/// message handlers registered in `Server::new`.
pub struct Server {
    /// This server's id. Behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Live connections; also cloned into the background cleanup task spawned by `start`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept (see `add_handler`).
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag: set to `true` by `teardown`, back to `false` by the test-only `reset`.
    teardown: watch::Sender<bool>,
}
 275
/// Lock guard over the shared `ConnectionPool`, as returned by `Session::connection_pool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is `!Send`, which makes the whole guard `!Send`: it cannot be held across an
    // `.await` inside a future that must be `Send`.
    _not_send: PhantomData<Rc<()>>,
}
 280
/// Serializable view of the server's peer and connection pool.
///
/// Because it holds a `ConnectionPoolGuard`, the connection pool stays locked for as long
/// as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard is serialized as the pool it dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 287
 288pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 289where
 290    S: Serializer,
 291    T: Deref<Target = U>,
 292    U: Serialize,
 293{
 294    Serialize::serialize(value.deref(), serializer)
 295}
 296
 297impl Server {
 298    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 299        let mut server = Self {
 300            id: parking_lot::Mutex::new(id),
 301            peer: Peer::new(id.0 as u32),
 302            app_state,
 303            connection_pool: Default::default(),
 304            handlers: Default::default(),
 305            teardown: watch::channel(false).0,
 306        };
 307
 308        server
 309            .add_request_handler(ping)
 310            .add_request_handler(create_room)
 311            .add_request_handler(join_room)
 312            .add_request_handler(rejoin_room)
 313            .add_request_handler(leave_room)
 314            .add_request_handler(set_room_participant_role)
 315            .add_request_handler(call)
 316            .add_request_handler(cancel_call)
 317            .add_message_handler(decline_call)
 318            .add_request_handler(update_participant_location)
 319            .add_request_handler(share_project)
 320            .add_message_handler(unshare_project)
 321            .add_request_handler(join_project)
 322            .add_message_handler(leave_project)
 323            .add_request_handler(update_project)
 324            .add_request_handler(update_worktree)
 325            .add_request_handler(update_repository)
 326            .add_request_handler(remove_repository)
 327            .add_message_handler(start_language_server)
 328            .add_message_handler(update_language_server)
 329            .add_message_handler(update_diagnostic_summary)
 330            .add_message_handler(update_worktree_settings)
 331            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
 332            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 333            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
 334            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 335            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 336            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 337            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 338            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 339            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
 340            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 341            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
 342            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 343            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
 344            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 345            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 346            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
 347            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
 348            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
 349            .add_request_handler(
 350                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
 351            )
 352            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
 353            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
 354            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
 355            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
 356            .add_request_handler(
 357                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 358            )
 359            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 360            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 361            .add_request_handler(
 362                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 363            )
 364            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 365            .add_request_handler(
 366                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 367            )
 368            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 369            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 370            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 371            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 372            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 373            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 374            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 375            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 376            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 377            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 378            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 379            .add_request_handler(
 380                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 381            )
 382            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 383            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 384            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 385            .add_request_handler(lsp_query)
 386            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
 387            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 388            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
 389            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 390            .add_message_handler(create_buffer_for_peer)
 391            .add_message_handler(create_image_for_peer)
 392            .add_request_handler(update_buffer)
 393            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 394            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
 395            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 396            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 397            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 398            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 399            .add_message_handler(
 400                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
 401            )
 402            .add_request_handler(get_users)
 403            .add_request_handler(fuzzy_search_users)
 404            .add_request_handler(request_contact)
 405            .add_request_handler(remove_contact)
 406            .add_request_handler(respond_to_contact_request)
 407            .add_message_handler(subscribe_to_channels)
 408            .add_request_handler(create_channel)
 409            .add_request_handler(delete_channel)
 410            .add_request_handler(invite_channel_member)
 411            .add_request_handler(remove_channel_member)
 412            .add_request_handler(set_channel_member_role)
 413            .add_request_handler(set_channel_visibility)
 414            .add_request_handler(rename_channel)
 415            .add_request_handler(join_channel_buffer)
 416            .add_request_handler(leave_channel_buffer)
 417            .add_message_handler(update_channel_buffer)
 418            .add_request_handler(rejoin_channel_buffers)
 419            .add_request_handler(get_channel_members)
 420            .add_request_handler(respond_to_channel_invite)
 421            .add_request_handler(join_channel)
 422            .add_request_handler(join_channel_chat)
 423            .add_message_handler(leave_channel_chat)
 424            .add_request_handler(send_channel_message)
 425            .add_request_handler(remove_channel_message)
 426            .add_request_handler(update_channel_message)
 427            .add_request_handler(get_channel_messages)
 428            .add_request_handler(get_channel_messages_by_id)
 429            .add_request_handler(get_notifications)
 430            .add_request_handler(mark_notification_as_read)
 431            .add_request_handler(move_channel)
 432            .add_request_handler(reorder_channel)
 433            .add_request_handler(follow)
 434            .add_message_handler(unfollow)
 435            .add_message_handler(update_followers)
 436            .add_message_handler(acknowledge_channel_message)
 437            .add_message_handler(acknowledge_buffer_version)
 438            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 439            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 440            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 441            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 442            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 443            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
 444            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
 445            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
 446            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 447            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
 448            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 449            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 450            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 451            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
 452            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 453            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 454            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 455            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
 456            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
 457            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 458            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 459            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
 460            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
 461            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 462            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 463            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
 464            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
 465            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 466            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 467            .add_message_handler(update_context)
 468            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
 469            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
 470            .add_request_handler(share_agent_thread)
 471            .add_request_handler(get_shared_agent_thread);
 472
 473        Arc::new(server)
 474    }
 475
    /// Spawns a detached background task that waits `CLEANUP_TIMEOUT` and then cleans up
    /// database and LiveKit state left behind by stale servers (as identified by the
    /// `stale_*` database queries for this environment and server id).
    ///
    /// Every step is best-effort: failures are logged via `trace_err()` and skipped. In
    /// order, the task:
    /// 1. calls `delete_stale_channel_chat_participants`;
    /// 2. for each channel buffer with stale collaborators, clears them and sends the
    ///    refreshed collaborator list to the buffer's remaining connections;
    /// 3. for each stale room, clears stale participants, broadcasts the updated room (and
    ///    its channel, if any), sends `CallCanceled` to users whose calls were canceled,
    ///    pushes an updated contact entry for every affected user to their accepted
    ///    contacts, and deletes the LiveKit room if no participants remain;
    /// 4. calls `delete_stale_channel_chat_participants` again, then
    ///    `clear_old_worktree_entries` and `delete_stale_servers`.
    ///
    /// Returns `Ok(())` as soon as the task has been spawned.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is created here but only awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Clear stale collaborators from each channel buffer and send the
                    // refreshed collaborator list to the connections that remain.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled in from the refreshed room below. If clearing the room fails,
                        // they keep these empty/false defaults, so the steps after it are no-ops.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both the removed (stale) participants and the users whose calls
                            // were canceled get a fresh contact entry pushed below.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Notify every connection of each user whose call was canceled.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                // The pool lock is taken only after the awaits above, so it is
                                // never held across an `.await`.
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts are told about the update.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room only when the refreshed room has no
                        // participants left (stays false if the refresh above failed).
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                // NOTE(review): this repeats the call made right after the timeout above and
                // looks redundant. Confirm that no new stale participants can appear in between
                // before removing it.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 646
 647    pub fn teardown(&self) {
 648        self.peer.teardown();
 649        self.connection_pool.lock().reset();
 650        let _ = self.teardown.send(true);
 651    }
 652
 653    #[cfg(test)]
 654    pub fn reset(&self, id: ServerId) {
 655        self.teardown();
 656        *self.id.lock() = id;
 657        self.peer.reset(id.0 as u32);
 658        let _ = self.teardown.send(false);
 659    }
 660
 661    #[cfg(test)]
 662    pub fn id(&self) -> ServerId {
 663        *self.id.lock()
 664    }
 665
 666    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 667    where
 668        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
 669        Fut: 'static + Send + Future<Output = Result<()>>,
 670        M: EnvelopedMessage,
 671    {
 672        let prev_handler = self.handlers.insert(
 673            TypeId::of::<M>(),
 674            Box::new(move |envelope, session, span| {
 675                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 676                let received_at = envelope.received_at;
 677                tracing::info!("message received");
 678                let start_time = Instant::now();
 679                let future = (handler)(
 680                    *envelope,
 681                    MessageContext {
 682                        session,
 683                        span: span.clone(),
 684                    },
 685                );
 686                async move {
 687                    let result = future.await;
 688                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 689                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 690                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 691                    span.record(TOTAL_DURATION_MS, total_duration_ms);
 692                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
 693                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
 694                    match result {
 695                        Err(error) => {
 696                            tracing::error!(?error, "error handling message")
 697                        }
 698                        Ok(()) => tracing::info!("finished handling message"),
 699                    }
 700                }
 701                .boxed()
 702            }),
 703        );
 704        if prev_handler.is_some() {
 705            panic!("registered a handler for the same message twice");
 706        }
 707        self
 708    }
 709
 710    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 711    where
 712        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
 713        Fut: 'static + Send + Future<Output = Result<()>>,
 714        M: EnvelopedMessage,
 715    {
 716        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 717        self
 718    }
 719
 720    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 721    where
 722        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
 723        Fut: Send + Future<Output = Result<()>>,
 724        M: RequestMessage,
 725    {
 726        let handler = Arc::new(handler);
 727        self.add_handler(move |envelope, session| {
 728            let receipt = envelope.receipt();
 729            let handler = handler.clone();
 730            async move {
 731                let peer = session.peer.clone();
 732                let responded = Arc::new(AtomicBool::default());
 733                let response = Response {
 734                    peer: peer.clone(),
 735                    responded: responded.clone(),
 736                    receipt,
 737                };
 738                match (handler)(envelope.payload, response, session).await {
 739                    Ok(()) => {
 740                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 741                            Ok(())
 742                        } else {
 743                            Err(anyhow!("handler did not send a response"))?
 744                        }
 745                    }
 746                    Err(error) => {
 747                        let proto_err = match &error {
 748                            Error::Internal(err) => err.to_proto(),
 749                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
 750                        };
 751                        peer.respond_with_error(receipt, proto_err)?;
 752                        Err(error)
 753                    }
 754                }
 755            }
 756        })
 757    }
 758
    /// Drives a single client connection from handshake to sign-out.
    ///
    /// Connection metadata (address, user agent, release channel, GeoIP country)
    /// is recorded on a `handle connection` span that instruments the returned
    /// future. When polled, the future:
    /// 1. returns immediately if the server is already tearing down;
    /// 2. registers the connection with the peer;
    /// 3. performs the initial client update;
    /// 4. dispatches incoming messages to the registered handlers.
    ///
    /// Dispatch continues until the I/O task finishes or the client stream ends,
    /// in which case `connection_lost` is invoked. It also stops when the teardown
    /// channel changes, in which case the future returns without signing out.
    ///
    /// `connection_guard` is released once the initial client update succeeds, or
    /// when the future returns early because that update failed.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields left `Empty` are filled in later: connection_id once the peer
        // assigns it, the rest by `principal.update_span` and the records below.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribe before the future is polled so a teardown that happens in the
        // meantime is still observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // Handshake done: release the concurrent-connection guard that was
            // acquired before the WebSocket upgrade.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            //
            // Each handler (foreground or background) holds one semaphore permit until
            // it finishes, capping in-flight handlers at MAX_CONCURRENT_HANDLERS.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // A permit is taken *before* reading the next message, so no new
                // message is pulled while all handler slots are busy.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // Biased polling order: teardown, then I/O completion, then progress
                // on running foreground handlers, then the next incoming message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // The permit moves into the handler future and is
                                // released only when the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                // `permit` is dropped at the end of this arm,
                                // freeing the slot immediately.
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still pending.
            // Background handlers were handed to `executor.spawn_detached` and are
            // not in this set.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
 931
 932    async fn send_initial_client_update(
 933        &self,
 934        connection_id: ConnectionId,
 935        zed_version: ZedVersion,
 936        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 937        session: &Session,
 938    ) -> Result<()> {
 939        self.peer.send(
 940            connection_id,
 941            proto::Hello {
 942                peer_id: Some(connection_id.into()),
 943            },
 944        )?;
 945        tracing::info!("sent hello message");
 946        if let Some(send_connection_id) = send_connection_id.take() {
 947            let _ = send_connection_id.send(connection_id);
 948        }
 949
 950        match &session.principal {
 951            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 952                if !user.connected_once {
 953                    self.peer.send(connection_id, proto::ShowContacts {})?;
 954                    self.app_state
 955                        .db
 956                        .set_user_connected_once(user.id, true)
 957                        .await?;
 958                }
 959
 960                let contacts = self.app_state.db.get_contacts(user.id).await?;
 961
 962                {
 963                    let mut pool = self.connection_pool.lock();
 964                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
 965                    self.peer.send(
 966                        connection_id,
 967                        build_initial_contacts_update(contacts, &pool),
 968                    )?;
 969                }
 970
 971                if should_auto_subscribe_to_channels(&zed_version) {
 972                    subscribe_user_to_channels(user.id, session).await?;
 973                }
 974
 975                if let Some(incoming_call) =
 976                    self.app_state.db.incoming_call_for_user(user.id).await?
 977                {
 978                    self.peer.send(connection_id, incoming_call)?;
 979                }
 980
 981                update_user_contacts(user.id, session).await?;
 982            }
 983        }
 984
 985        Ok(())
 986    }
 987
 988    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
 989        ServerSnapshot {
 990            connection_pool: ConnectionPoolGuard {
 991                guard: self.connection_pool.lock(),
 992                _not_send: PhantomData,
 993            },
 994            peer: &self.peer,
 995        }
 996    }
 997}
 998
 999impl Deref for ConnectionPoolGuard<'_> {
1000    type Target = ConnectionPool;
1001
1002    fn deref(&self) -> &Self::Target {
1003        &self.guard
1004    }
1005}
1006
1007impl DerefMut for ConnectionPoolGuard<'_> {
1008    fn deref_mut(&mut self) -> &mut Self::Target {
1009        &mut self.guard
1010    }
1011}
1012
1013impl Drop for ConnectionPoolGuard<'_> {
1014    fn drop(&mut self) {
1015        #[cfg(test)]
1016        self.check_invariants();
1017    }
1018}
1019
1020fn broadcast<F>(
1021    sender_id: Option<ConnectionId>,
1022    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1023    mut f: F,
1024) where
1025    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1026{
1027    for receiver_id in receiver_ids {
1028        if Some(receiver_id) != sender_id
1029            && let Err(error) = f(receiver_id)
1030        {
1031            tracing::error!("failed to send to {:?} {}", receiver_id, error);
1032        }
1033    }
1034}
1035
1036pub struct ProtocolVersion(u32);
1037
1038impl Header for ProtocolVersion {
1039    fn name() -> &'static HeaderName {
1040        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1041        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1042    }
1043
1044    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1045    where
1046        Self: Sized,
1047        I: Iterator<Item = &'i axum::http::HeaderValue>,
1048    {
1049        let version = values
1050            .next()
1051            .ok_or_else(axum::headers::Error::invalid)?
1052            .to_str()
1053            .map_err(|_| axum::headers::Error::invalid())?
1054            .parse()
1055            .map_err(|_| axum::headers::Error::invalid())?;
1056        Ok(Self(version))
1057    }
1058
1059    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1060        values.extend([self.0.to_string().parse().unwrap()]);
1061    }
1062}
1063
1064pub struct AppVersionHeader(Version);
1065impl Header for AppVersionHeader {
1066    fn name() -> &'static HeaderName {
1067        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1068        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1069    }
1070
1071    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1072    where
1073        Self: Sized,
1074        I: Iterator<Item = &'i axum::http::HeaderValue>,
1075    {
1076        let version = values
1077            .next()
1078            .ok_or_else(axum::headers::Error::invalid)?
1079            .to_str()
1080            .map_err(|_| axum::headers::Error::invalid())?
1081            .parse()
1082            .map_err(|_| axum::headers::Error::invalid())?;
1083        Ok(Self(version))
1084    }
1085
1086    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1087        values.extend([self.0.to_string().parse().unwrap()]);
1088    }
1089}
1090
1091#[derive(Debug)]
1092pub struct ReleaseChannelHeader(String);
1093
1094impl Header for ReleaseChannelHeader {
1095    fn name() -> &'static HeaderName {
1096        static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1097        ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1098    }
1099
1100    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1101    where
1102        Self: Sized,
1103        I: Iterator<Item = &'i axum::http::HeaderValue>,
1104    {
1105        Ok(Self(
1106            values
1107                .next()
1108                .ok_or_else(axum::headers::Error::invalid)?
1109                .to_str()
1110                .map_err(|_| axum::headers::Error::invalid())?
1111                .to_owned(),
1112        ))
1113    }
1114
1115    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1116        values.extend([self.0.parse().unwrap()]);
1117    }
1118}
1119
1120pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1121    Router::new()
1122        .route("/rpc", get(handle_websocket_request))
1123        .layer(
1124            ServiceBuilder::new()
1125                .layer(Extension(server.app_state.clone()))
1126                .layer(middleware::from_fn(auth::validate_header)),
1127        )
1128        .route("/metrics", get(handle_metrics))
1129        .layer(Extension(server))
1130}
1131
1132pub async fn handle_websocket_request(
1133    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1134    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1135    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1136    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1137    Extension(server): Extension<Arc<Server>>,
1138    Extension(principal): Extension<Principal>,
1139    user_agent: Option<TypedHeader<UserAgent>>,
1140    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1141    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1142    ws: WebSocketUpgrade,
1143) -> axum::response::Response {
1144    if protocol_version != rpc::PROTOCOL_VERSION {
1145        return (
1146            StatusCode::UPGRADE_REQUIRED,
1147            "client must be upgraded".to_string(),
1148        )
1149            .into_response();
1150    }
1151
1152    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1153        return (
1154            StatusCode::UPGRADE_REQUIRED,
1155            "no version header found".to_string(),
1156        )
1157            .into_response();
1158    };
1159
1160    let release_channel = release_channel_header.map(|header| header.0.0);
1161
1162    if !version.can_collaborate() {
1163        return (
1164            StatusCode::UPGRADE_REQUIRED,
1165            "client must be upgraded".to_string(),
1166        )
1167            .into_response();
1168    }
1169
1170    let socket_address = socket_address.to_string();
1171
1172    // Acquire connection guard before WebSocket upgrade
1173    let connection_guard = match ConnectionGuard::try_acquire() {
1174        Ok(guard) => guard,
1175        Err(()) => {
1176            return (
1177                StatusCode::SERVICE_UNAVAILABLE,
1178                "Too many concurrent connections",
1179            )
1180                .into_response();
1181        }
1182    };
1183
1184    ws.on_upgrade(move |socket| {
1185        let socket = socket
1186            .map_ok(to_tungstenite_message)
1187            .err_into()
1188            .with(|message| async move { to_axum_message(message) });
1189        let connection = Connection::new(Box::pin(socket));
1190        async move {
1191            server
1192                .handle_connection(
1193                    connection,
1194                    socket_address,
1195                    principal,
1196                    version,
1197                    release_channel,
1198                    user_agent.map(|header| header.to_string()),
1199                    country_code_header.map(|header| header.to_string()),
1200                    system_id_header.map(|header| header.to_string()),
1201                    None,
1202                    Executor::Production,
1203                    Some(connection_guard),
1204                )
1205                .await;
1206        }
1207    })
1208}
1209
1210pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1211    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1212    let connections_metric = CONNECTIONS_METRIC
1213        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1214
1215    let connections = server
1216        .connection_pool
1217        .lock()
1218        .connections()
1219        .filter(|connection| !connection.admin)
1220        .count();
1221    connections_metric.set(connections as _);
1222
1223    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1224    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1225        register_int_gauge!(
1226            "shared_projects",
1227            "number of open projects with one or more guests"
1228        )
1229        .unwrap()
1230    });
1231
1232    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1233    shared_projects_metric.set(shared_projects as _);
1234
1235    let encoder = prometheus::TextEncoder::new();
1236    let metric_families = prometheus::gather();
1237    let encoded_metrics = encoder
1238        .encode_to_string(&metric_families)
1239        .map_err(|err| anyhow!("{err}"))?;
1240    Ok(encoded_metrics)
1241}
1242
1243#[instrument(err, skip(executor))]
1244async fn connection_lost(
1245    session: Session,
1246    mut teardown: watch::Receiver<bool>,
1247    executor: Executor,
1248) -> Result<()> {
1249    session.peer.disconnect(session.connection_id);
1250    session
1251        .connection_pool()
1252        .await
1253        .remove_connection(session.connection_id)?;
1254
1255    session
1256        .db()
1257        .await
1258        .connection_lost(session.connection_id)
1259        .await
1260        .trace_err();
1261
1262    futures::select_biased! {
1263        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1264
1265            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1266            leave_room_for_session(&session, session.connection_id).await.trace_err();
1267            leave_channel_buffers_for_session(&session)
1268                .await
1269                .trace_err();
1270
1271            if !session
1272                .connection_pool()
1273                .await
1274                .is_user_online(session.user_id())
1275            {
1276                let db = session.db().await;
1277                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1278                    room_updated(&room, &session.peer);
1279                }
1280            }
1281
1282            update_user_contacts(session.user_id(), &session).await?;
1283        },
1284        _ = teardown.changed().fuse() => {}
1285    }
1286
1287    Ok(())
1288}
1289
1290/// Acknowledges a ping from a client, used to keep the connection alive.
1291async fn ping(
1292    _: proto::Ping,
1293    response: Response<proto::Ping>,
1294    _session: MessageContext,
1295) -> Result<()> {
1296    response.send(proto::Ack {})?;
1297    Ok(())
1298}
1299
1300/// Creates a new room for calling (outside of channels)
1301async fn create_room(
1302    _request: proto::CreateRoom,
1303    response: Response<proto::CreateRoom>,
1304    session: MessageContext,
1305) -> Result<()> {
1306    let livekit_room = nanoid::nanoid!(30);
1307
1308    let live_kit_connection_info = util::maybe!(async {
1309        let live_kit = session.app_state.livekit_client.as_ref();
1310        let live_kit = live_kit?;
1311        let user_id = session.user_id().to_string();
1312
1313        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1314
1315        Some(proto::LiveKitConnectionInfo {
1316            server_url: live_kit.url().into(),
1317            token,
1318            can_publish: true,
1319        })
1320    })
1321    .await;
1322
1323    let room = session
1324        .db()
1325        .await
1326        .create_room(session.user_id(), session.connection_id, &livekit_room)
1327        .await?;
1328
1329    response.send(proto::CreateRoomResponse {
1330        room: Some(room.clone()),
1331        live_kit_connection_info,
1332    })?;
1333
1334    update_user_contacts(session.user_id(), &session).await?;
1335    Ok(())
1336}
1337
1338/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1339async fn join_room(
1340    request: proto::JoinRoom,
1341    response: Response<proto::JoinRoom>,
1342    session: MessageContext,
1343) -> Result<()> {
1344    let room_id = RoomId::from_proto(request.id);
1345
1346    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1347
1348    if let Some(channel_id) = channel_id {
1349        return join_channel_internal(channel_id, Box::new(response), session).await;
1350    }
1351
1352    let joined_room = {
1353        let room = session
1354            .db()
1355            .await
1356            .join_room(room_id, session.user_id(), session.connection_id)
1357            .await?;
1358        room_updated(&room.room, &session.peer);
1359        room.into_inner()
1360    };
1361
1362    for connection_id in session
1363        .connection_pool()
1364        .await
1365        .user_connection_ids(session.user_id())
1366    {
1367        session
1368            .peer
1369            .send(
1370                connection_id,
1371                proto::CallCanceled {
1372                    room_id: room_id.to_proto(),
1373                },
1374            )
1375            .trace_err();
1376    }
1377
1378    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1379    {
1380        live_kit
1381            .room_token(
1382                &joined_room.room.livekit_room,
1383                &session.user_id().to_string(),
1384            )
1385            .trace_err()
1386            .map(|token| proto::LiveKitConnectionInfo {
1387                server_url: live_kit.url().into(),
1388                token,
1389                can_publish: true,
1390            })
1391    } else {
1392        None
1393    };
1394
1395    response.send(proto::JoinRoomResponse {
1396        room: Some(joined_room.room),
1397        channel_id: None,
1398        live_kit_connection_info,
1399    })?;
1400
1401    update_user_contacts(session.user_id(), &session).await?;
1402    Ok(())
1403}
1404
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Steps, in order:
/// 1. Reply with the room plus the reshared and rejoined projects.
/// 2. Invoke `room_updated`.
/// 3. For every reshared project:
///    - tell each collaborator the caller's old peer id was replaced by this
///      connection;
///    - forward an `UpdateProject` with the project's worktrees to the other
///      collaborators.
/// 4. Notify collaborators of rejoined projects via `notify_rejoined_projects`.
///
/// Finally, if the room belongs to a channel, `channel_updated` is called, and
/// the user's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    // Assigned inside the inner scope so the value returned by `db.rejoin_room`
    // (presumably a transaction guard, given the `into_inner` below) is gone
    // before `channel_updated` locks the connection pool.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Point every collaborator at this connection instead of the old one.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the worktree list to everyone except the rejoining host.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Takes `&mut` because it consumes the rejoined projects' worktrees.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1496
1497fn notify_rejoined_projects(
1498    rejoined_projects: &mut Vec<RejoinedProject>,
1499    session: &Session,
1500) -> Result<()> {
1501    for project in rejoined_projects.iter() {
1502        for collaborator in &project.collaborators {
1503            session
1504                .peer
1505                .send(
1506                    collaborator.connection_id,
1507                    proto::UpdateProjectCollaborator {
1508                        project_id: project.id.to_proto(),
1509                        old_peer_id: Some(project.old_connection_id.into()),
1510                        new_peer_id: Some(session.connection_id.into()),
1511                    },
1512                )
1513                .trace_err();
1514        }
1515    }
1516
1517    for project in rejoined_projects {
1518        for worktree in mem::take(&mut project.worktrees) {
1519            // Stream this worktree's entries.
1520            let message = proto::UpdateWorktree {
1521                project_id: project.id.to_proto(),
1522                worktree_id: worktree.id,
1523                abs_path: worktree.abs_path.clone(),
1524                root_name: worktree.root_name,
1525                updated_entries: worktree.updated_entries,
1526                removed_entries: worktree.removed_entries,
1527                scan_id: worktree.scan_id,
1528                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1529                updated_repositories: worktree.updated_repositories,
1530                removed_repositories: worktree.removed_repositories,
1531            };
1532            for update in proto::split_worktree_update(message) {
1533                session.peer.send(session.connection_id, update)?;
1534            }
1535
1536            // Stream this worktree's diagnostics.
1537            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1538            if let Some(summary) = worktree_diagnostics.next() {
1539                let message = proto::UpdateDiagnosticSummary {
1540                    project_id: project.id.to_proto(),
1541                    worktree_id: worktree.id,
1542                    summary: Some(summary),
1543                    more_summaries: worktree_diagnostics.collect(),
1544                };
1545                session.peer.send(session.connection_id, message)?;
1546            }
1547
1548            for settings_file in worktree.settings_files {
1549                session.peer.send(
1550                    session.connection_id,
1551                    proto::UpdateWorktreeSettings {
1552                        project_id: project.id.to_proto(),
1553                        worktree_id: worktree.id,
1554                        path: settings_file.path,
1555                        content: Some(settings_file.content),
1556                        kind: Some(settings_file.kind.to_proto().into()),
1557                    },
1558                )?;
1559            }
1560        }
1561
1562        for repository in mem::take(&mut project.updated_repositories) {
1563            for update in split_repository_update(repository) {
1564                session.peer.send(session.connection_id, update)?;
1565            }
1566        }
1567
1568        for id in mem::take(&mut project.removed_repositories) {
1569            session.peer.send(
1570                session.connection_id,
1571                proto::RemoveRepository {
1572                    project_id: project.id.to_proto(),
1573                    id,
1574                },
1575            )?;
1576        }
1577    }
1578
1579    Ok(())
1580}
1581
1582/// leave room disconnects from the room.
1583async fn leave_room(
1584    _: proto::LeaveRoom,
1585    response: Response<proto::LeaveRoom>,
1586    session: MessageContext,
1587) -> Result<()> {
1588    leave_room_for_session(&session, session.connection_id).await?;
1589    response.send(proto::Ack {})?;
1590    Ok(())
1591}
1592
1593/// Updates the permissions of someone else in the room.
1594async fn set_room_participant_role(
1595    request: proto::SetRoomParticipantRole,
1596    response: Response<proto::SetRoomParticipantRole>,
1597    session: MessageContext,
1598) -> Result<()> {
1599    let user_id = UserId::from_proto(request.user_id);
1600    let role = ChannelRole::from(request.role());
1601
1602    let (livekit_room, can_publish) = {
1603        let room = session
1604            .db()
1605            .await
1606            .set_room_participant_role(
1607                session.user_id(),
1608                RoomId::from_proto(request.room_id),
1609                user_id,
1610                role,
1611            )
1612            .await?;
1613
1614        let livekit_room = room.livekit_room.clone();
1615        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1616        room_updated(&room, &session.peer);
1617        (livekit_room, can_publish)
1618    };
1619
1620    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1621        live_kit
1622            .update_participant(
1623                livekit_room.clone(),
1624                request.user_id.to_string(),
1625                livekit_api::proto::ParticipantPermission {
1626                    can_subscribe: true,
1627                    can_publish,
1628                    can_publish_data: can_publish,
1629                    hidden: false,
1630                    recorder: false,
1631                },
1632            )
1633            .await
1634            .trace_err();
1635    }
1636
1637    response.send(proto::Ack {})?;
1638    Ok(())
1639}
1640
1641/// Call someone else into the current room
1642async fn call(
1643    request: proto::Call,
1644    response: Response<proto::Call>,
1645    session: MessageContext,
1646) -> Result<()> {
1647    let room_id = RoomId::from_proto(request.room_id);
1648    let calling_user_id = session.user_id();
1649    let calling_connection_id = session.connection_id;
1650    let called_user_id = UserId::from_proto(request.called_user_id);
1651    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1652    if !session
1653        .db()
1654        .await
1655        .has_contact(calling_user_id, called_user_id)
1656        .await?
1657    {
1658        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1659    }
1660
1661    let incoming_call = {
1662        let (room, incoming_call) = &mut *session
1663            .db()
1664            .await
1665            .call(
1666                room_id,
1667                calling_user_id,
1668                calling_connection_id,
1669                called_user_id,
1670                initial_project_id,
1671            )
1672            .await?;
1673        room_updated(room, &session.peer);
1674        mem::take(incoming_call)
1675    };
1676    update_user_contacts(called_user_id, &session).await?;
1677
1678    let mut calls = session
1679        .connection_pool()
1680        .await
1681        .user_connection_ids(called_user_id)
1682        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1683        .collect::<FuturesUnordered<_>>();
1684
1685    while let Some(call_response) = calls.next().await {
1686        match call_response.as_ref() {
1687            Ok(_) => {
1688                response.send(proto::Ack {})?;
1689                return Ok(());
1690            }
1691            Err(_) => {
1692                call_response.trace_err();
1693            }
1694        }
1695    }
1696
1697    {
1698        let room = session
1699            .db()
1700            .await
1701            .call_failed(room_id, called_user_id)
1702            .await?;
1703        room_updated(&room, &session.peer);
1704    }
1705    update_user_contacts(called_user_id, &session).await?;
1706
1707    Err(anyhow!("failed to ring user"))?
1708}
1709
1710/// Cancel an outgoing call.
1711async fn cancel_call(
1712    request: proto::CancelCall,
1713    response: Response<proto::CancelCall>,
1714    session: MessageContext,
1715) -> Result<()> {
1716    let called_user_id = UserId::from_proto(request.called_user_id);
1717    let room_id = RoomId::from_proto(request.room_id);
1718    {
1719        let room = session
1720            .db()
1721            .await
1722            .cancel_call(room_id, session.connection_id, called_user_id)
1723            .await?;
1724        room_updated(&room, &session.peer);
1725    }
1726
1727    for connection_id in session
1728        .connection_pool()
1729        .await
1730        .user_connection_ids(called_user_id)
1731    {
1732        session
1733            .peer
1734            .send(
1735                connection_id,
1736                proto::CallCanceled {
1737                    room_id: room_id.to_proto(),
1738                },
1739            )
1740            .trace_err();
1741    }
1742    response.send(proto::Ack {})?;
1743
1744    update_user_contacts(called_user_id, &session).await?;
1745    Ok(())
1746}
1747
1748/// Decline an incoming call.
1749async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1750    let room_id = RoomId::from_proto(message.room_id);
1751    {
1752        let room = session
1753            .db()
1754            .await
1755            .decline_call(Some(room_id), session.user_id())
1756            .await?
1757            .context("declining call")?;
1758        room_updated(&room, &session.peer);
1759    }
1760
1761    for connection_id in session
1762        .connection_pool()
1763        .await
1764        .user_connection_ids(session.user_id())
1765    {
1766        session
1767            .peer
1768            .send(
1769                connection_id,
1770                proto::CallCanceled {
1771                    room_id: room_id.to_proto(),
1772                },
1773            )
1774            .trace_err();
1775    }
1776    update_user_contacts(session.user_id(), &session).await?;
1777    Ok(())
1778}
1779
1780/// Updates other participants in the room with your current location.
1781async fn update_participant_location(
1782    request: proto::UpdateParticipantLocation,
1783    response: Response<proto::UpdateParticipantLocation>,
1784    session: MessageContext,
1785) -> Result<()> {
1786    let room_id = RoomId::from_proto(request.room_id);
1787    let location = request.location.context("invalid location")?;
1788
1789    let db = session.db().await;
1790    let room = db
1791        .update_room_participant_location(room_id, session.connection_id, location)
1792        .await?;
1793
1794    room_updated(&room, &session.peer);
1795    response.send(proto::Ack {})?;
1796    Ok(())
1797}
1798
1799/// Share a project into the room.
1800async fn share_project(
1801    request: proto::ShareProject,
1802    response: Response<proto::ShareProject>,
1803    session: MessageContext,
1804) -> Result<()> {
1805    let (project_id, room) = &*session
1806        .db()
1807        .await
1808        .share_project(
1809            RoomId::from_proto(request.room_id),
1810            session.connection_id,
1811            &request.worktrees,
1812            request.is_ssh_project,
1813            request.windows_paths.unwrap_or(false),
1814        )
1815        .await?;
1816    response.send(proto::ShareProjectResponse {
1817        project_id: project_id.to_proto(),
1818    })?;
1819    room_updated(room, &session.peer);
1820
1821    Ok(())
1822}
1823
1824/// Unshare a project from the room.
1825async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1826    let project_id = ProjectId::from_proto(message.project_id);
1827    unshare_project_internal(project_id, session.connection_id, &session).await
1828}
1829
1830async fn unshare_project_internal(
1831    project_id: ProjectId,
1832    connection_id: ConnectionId,
1833    session: &Session,
1834) -> Result<()> {
1835    let delete = {
1836        let room_guard = session
1837            .db()
1838            .await
1839            .unshare_project(project_id, connection_id)
1840            .await?;
1841
1842        let (delete, room, guest_connection_ids) = &*room_guard;
1843
1844        let message = proto::UnshareProject {
1845            project_id: project_id.to_proto(),
1846        };
1847
1848        broadcast(
1849            Some(connection_id),
1850            guest_connection_ids.iter().copied(),
1851            |conn_id| session.peer.send(conn_id, message.clone()),
1852        );
1853        if let Some(room) = room {
1854            room_updated(room, &session.peer);
1855        }
1856
1857        *delete
1858    };
1859
1860    if delete {
1861        let db = session.db().await;
1862        db.delete_project(project_id).await?;
1863    }
1864
1865    Ok(())
1866}
1867
1868/// Join someone elses shared project.
1869async fn join_project(
1870    request: proto::JoinProject,
1871    response: Response<proto::JoinProject>,
1872    session: MessageContext,
1873) -> Result<()> {
1874    let project_id = ProjectId::from_proto(request.project_id);
1875
1876    tracing::info!(%project_id, "join project");
1877
1878    let db = session.db().await;
1879    let (project, replica_id) = &mut *db
1880        .join_project(
1881            project_id,
1882            session.connection_id,
1883            session.user_id(),
1884            request.committer_name.clone(),
1885            request.committer_email.clone(),
1886        )
1887        .await?;
1888    drop(db);
1889    tracing::info!(%project_id, "join remote project");
1890    let collaborators = project
1891        .collaborators
1892        .iter()
1893        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1894        .map(|collaborator| collaborator.to_proto())
1895        .collect::<Vec<_>>();
1896    let project_id = project.id;
1897    let guest_user_id = session.user_id();
1898
1899    let worktrees = project
1900        .worktrees
1901        .iter()
1902        .map(|(id, worktree)| proto::WorktreeMetadata {
1903            id: *id,
1904            root_name: worktree.root_name.clone(),
1905            visible: worktree.visible,
1906            abs_path: worktree.abs_path.clone(),
1907        })
1908        .collect::<Vec<_>>();
1909
1910    let add_project_collaborator = proto::AddProjectCollaborator {
1911        project_id: project_id.to_proto(),
1912        collaborator: Some(proto::Collaborator {
1913            peer_id: Some(session.connection_id.into()),
1914            replica_id: replica_id.0 as u32,
1915            user_id: guest_user_id.to_proto(),
1916            is_host: false,
1917            committer_name: request.committer_name.clone(),
1918            committer_email: request.committer_email.clone(),
1919        }),
1920    };
1921
1922    for collaborator in &collaborators {
1923        session
1924            .peer
1925            .send(
1926                collaborator.peer_id.unwrap().into(),
1927                add_project_collaborator.clone(),
1928            )
1929            .trace_err();
1930    }
1931
1932    // First, we send the metadata associated with each worktree.
1933    let (language_servers, language_server_capabilities) = project
1934        .language_servers
1935        .clone()
1936        .into_iter()
1937        .map(|server| (server.server, server.capabilities))
1938        .unzip();
1939    response.send(proto::JoinProjectResponse {
1940        project_id: project.id.0 as u64,
1941        worktrees,
1942        replica_id: replica_id.0 as u32,
1943        collaborators,
1944        language_servers,
1945        language_server_capabilities,
1946        role: project.role.into(),
1947        windows_paths: project.path_style == PathStyle::Windows,
1948    })?;
1949
1950    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1951        // Stream this worktree's entries.
1952        let message = proto::UpdateWorktree {
1953            project_id: project_id.to_proto(),
1954            worktree_id,
1955            abs_path: worktree.abs_path.clone(),
1956            root_name: worktree.root_name,
1957            updated_entries: worktree.entries,
1958            removed_entries: Default::default(),
1959            scan_id: worktree.scan_id,
1960            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1961            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1962            removed_repositories: Default::default(),
1963        };
1964        for update in proto::split_worktree_update(message) {
1965            session.peer.send(session.connection_id, update.clone())?;
1966        }
1967
1968        // Stream this worktree's diagnostics.
1969        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1970        if let Some(summary) = worktree_diagnostics.next() {
1971            let message = proto::UpdateDiagnosticSummary {
1972                project_id: project.id.to_proto(),
1973                worktree_id: worktree.id,
1974                summary: Some(summary),
1975                more_summaries: worktree_diagnostics.collect(),
1976            };
1977            session.peer.send(session.connection_id, message)?;
1978        }
1979
1980        for settings_file in worktree.settings_files {
1981            session.peer.send(
1982                session.connection_id,
1983                proto::UpdateWorktreeSettings {
1984                    project_id: project_id.to_proto(),
1985                    worktree_id: worktree.id,
1986                    path: settings_file.path,
1987                    content: Some(settings_file.content),
1988                    kind: Some(settings_file.kind.to_proto() as i32),
1989                },
1990            )?;
1991        }
1992    }
1993
1994    for repository in mem::take(&mut project.repositories) {
1995        for update in split_repository_update(repository) {
1996            session.peer.send(session.connection_id, update)?;
1997        }
1998    }
1999
2000    for language_server in &project.language_servers {
2001        session.peer.send(
2002            session.connection_id,
2003            proto::UpdateLanguageServer {
2004                project_id: project_id.to_proto(),
2005                server_name: Some(language_server.server.name.clone()),
2006                language_server_id: language_server.server.id,
2007                variant: Some(
2008                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2009                        proto::LspDiskBasedDiagnosticsUpdated {},
2010                    ),
2011                ),
2012            },
2013        )?;
2014    }
2015
2016    Ok(())
2017}
2018
2019/// Leave someone elses shared project.
2020async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2021    let sender_id = session.connection_id;
2022    let project_id = ProjectId::from_proto(request.project_id);
2023    let db = session.db().await;
2024
2025    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2026    tracing::info!(
2027        %project_id,
2028        "leave project"
2029    );
2030
2031    project_left(project, &session);
2032    if let Some(room) = room {
2033        room_updated(room, &session.peer);
2034    }
2035
2036    Ok(())
2037}
2038
2039/// Updates other participants with changes to the project
2040async fn update_project(
2041    request: proto::UpdateProject,
2042    response: Response<proto::UpdateProject>,
2043    session: MessageContext,
2044) -> Result<()> {
2045    let project_id = ProjectId::from_proto(request.project_id);
2046    let (room, guest_connection_ids) = &*session
2047        .db()
2048        .await
2049        .update_project(project_id, session.connection_id, &request.worktrees)
2050        .await?;
2051    broadcast(
2052        Some(session.connection_id),
2053        guest_connection_ids.iter().copied(),
2054        |connection_id| {
2055            session
2056                .peer
2057                .forward_send(session.connection_id, connection_id, request.clone())
2058        },
2059    );
2060    if let Some(room) = room {
2061        room_updated(room, &session.peer);
2062    }
2063    response.send(proto::Ack {})?;
2064
2065    Ok(())
2066}
2067
2068/// Updates other participants with changes to the worktree
2069async fn update_worktree(
2070    request: proto::UpdateWorktree,
2071    response: Response<proto::UpdateWorktree>,
2072    session: MessageContext,
2073) -> Result<()> {
2074    let guest_connection_ids = session
2075        .db()
2076        .await
2077        .update_worktree(&request, session.connection_id)
2078        .await?;
2079
2080    broadcast(
2081        Some(session.connection_id),
2082        guest_connection_ids.iter().copied(),
2083        |connection_id| {
2084            session
2085                .peer
2086                .forward_send(session.connection_id, connection_id, request.clone())
2087        },
2088    );
2089    response.send(proto::Ack {})?;
2090    Ok(())
2091}
2092
2093async fn update_repository(
2094    request: proto::UpdateRepository,
2095    response: Response<proto::UpdateRepository>,
2096    session: MessageContext,
2097) -> Result<()> {
2098    let guest_connection_ids = session
2099        .db()
2100        .await
2101        .update_repository(&request, session.connection_id)
2102        .await?;
2103
2104    broadcast(
2105        Some(session.connection_id),
2106        guest_connection_ids.iter().copied(),
2107        |connection_id| {
2108            session
2109                .peer
2110                .forward_send(session.connection_id, connection_id, request.clone())
2111        },
2112    );
2113    response.send(proto::Ack {})?;
2114    Ok(())
2115}
2116
2117async fn remove_repository(
2118    request: proto::RemoveRepository,
2119    response: Response<proto::RemoveRepository>,
2120    session: MessageContext,
2121) -> Result<()> {
2122    let guest_connection_ids = session
2123        .db()
2124        .await
2125        .remove_repository(&request, session.connection_id)
2126        .await?;
2127
2128    broadcast(
2129        Some(session.connection_id),
2130        guest_connection_ids.iter().copied(),
2131        |connection_id| {
2132            session
2133                .peer
2134                .forward_send(session.connection_id, connection_id, request.clone())
2135        },
2136    );
2137    response.send(proto::Ack {})?;
2138    Ok(())
2139}
2140
2141/// Updates other participants with changes to the diagnostics
2142async fn update_diagnostic_summary(
2143    message: proto::UpdateDiagnosticSummary,
2144    session: MessageContext,
2145) -> Result<()> {
2146    let guest_connection_ids = session
2147        .db()
2148        .await
2149        .update_diagnostic_summary(&message, session.connection_id)
2150        .await?;
2151
2152    broadcast(
2153        Some(session.connection_id),
2154        guest_connection_ids.iter().copied(),
2155        |connection_id| {
2156            session
2157                .peer
2158                .forward_send(session.connection_id, connection_id, message.clone())
2159        },
2160    );
2161
2162    Ok(())
2163}
2164
2165/// Updates other participants with changes to the worktree settings
2166async fn update_worktree_settings(
2167    message: proto::UpdateWorktreeSettings,
2168    session: MessageContext,
2169) -> Result<()> {
2170    let guest_connection_ids = session
2171        .db()
2172        .await
2173        .update_worktree_settings(&message, session.connection_id)
2174        .await?;
2175
2176    broadcast(
2177        Some(session.connection_id),
2178        guest_connection_ids.iter().copied(),
2179        |connection_id| {
2180            session
2181                .peer
2182                .forward_send(session.connection_id, connection_id, message.clone())
2183        },
2184    );
2185
2186    Ok(())
2187}
2188
2189/// Notify other participants that a language server has started.
2190async fn start_language_server(
2191    request: proto::StartLanguageServer,
2192    session: MessageContext,
2193) -> Result<()> {
2194    let guest_connection_ids = session
2195        .db()
2196        .await
2197        .start_language_server(&request, session.connection_id)
2198        .await?;
2199
2200    broadcast(
2201        Some(session.connection_id),
2202        guest_connection_ids.iter().copied(),
2203        |connection_id| {
2204            session
2205                .peer
2206                .forward_send(session.connection_id, connection_id, request.clone())
2207        },
2208    );
2209    Ok(())
2210}
2211
2212/// Notify other participants that a language server has changed.
2213async fn update_language_server(
2214    request: proto::UpdateLanguageServer,
2215    session: MessageContext,
2216) -> Result<()> {
2217    let project_id = ProjectId::from_proto(request.project_id);
2218    let db = session.db().await;
2219
2220    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2221        && let Some(capabilities) = update.capabilities.clone()
2222    {
2223        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2224            .await?;
2225    }
2226
2227    let project_connection_ids = db
2228        .project_connection_ids(project_id, session.connection_id, true)
2229        .await?;
2230    broadcast(
2231        Some(session.connection_id),
2232        project_connection_ids.iter().copied(),
2233        |connection_id| {
2234            session
2235                .peer
2236                .forward_send(session.connection_id, connection_id, request.clone())
2237        },
2238    );
2239    Ok(())
2240}
2241
2242/// forward a project request to the host. These requests should be read only
2243/// as guests are allowed to send them.
2244async fn forward_read_only_project_request<T>(
2245    request: T,
2246    response: Response<T>,
2247    session: MessageContext,
2248) -> Result<()>
2249where
2250    T: EntityMessage + RequestMessage,
2251{
2252    let project_id = ProjectId::from_proto(request.remote_entity_id());
2253    let host_connection_id = session
2254        .db()
2255        .await
2256        .host_for_read_only_project_request(project_id, session.connection_id)
2257        .await?;
2258    let payload = session.forward_request(host_connection_id, request).await?;
2259    response.send(payload)?;
2260    Ok(())
2261}
2262
2263/// forward a project request to the host. These requests are disallowed
2264/// for guests.
2265async fn forward_mutating_project_request<T>(
2266    request: T,
2267    response: Response<T>,
2268    session: MessageContext,
2269) -> Result<()>
2270where
2271    T: EntityMessage + RequestMessage,
2272{
2273    let project_id = ProjectId::from_proto(request.remote_entity_id());
2274
2275    let host_connection_id = session
2276        .db()
2277        .await
2278        .host_for_mutating_project_request(project_id, session.connection_id)
2279        .await?;
2280    let payload = session.forward_request(host_connection_id, request).await?;
2281    response.send(payload)?;
2282    Ok(())
2283}
2284
2285async fn lsp_query(
2286    request: proto::LspQuery,
2287    response: Response<proto::LspQuery>,
2288    session: MessageContext,
2289) -> Result<()> {
2290    let (name, should_write) = request.query_name_and_write_permissions();
2291    tracing::Span::current().record("lsp_query_request", name);
2292    tracing::info!("lsp_query message received");
2293    if should_write {
2294        forward_mutating_project_request(request, response, session).await
2295    } else {
2296        forward_read_only_project_request(request, response, session).await
2297    }
2298}
2299
2300/// Notify other participants that a new buffer has been created
2301async fn create_buffer_for_peer(
2302    request: proto::CreateBufferForPeer,
2303    session: MessageContext,
2304) -> Result<()> {
2305    session
2306        .db()
2307        .await
2308        .check_user_is_project_host(
2309            ProjectId::from_proto(request.project_id),
2310            session.connection_id,
2311        )
2312        .await?;
2313    let peer_id = request.peer_id.context("invalid peer id")?;
2314    session
2315        .peer
2316        .forward_send(session.connection_id, peer_id.into(), request)?;
2317    Ok(())
2318}
2319
2320/// Notify other participants that a new image has been created
2321async fn create_image_for_peer(
2322    request: proto::CreateImageForPeer,
2323    session: MessageContext,
2324) -> Result<()> {
2325    session
2326        .db()
2327        .await
2328        .check_user_is_project_host(
2329            ProjectId::from_proto(request.project_id),
2330            session.connection_id,
2331        )
2332        .await?;
2333    let peer_id = request.peer_id.context("invalid peer id")?;
2334    session
2335        .peer
2336        .forward_send(session.connection_id, peer_id.into(), request)?;
2337    Ok(())
2338}
2339
2340/// Notify other participants that a buffer has been updated. This is
2341/// allowed for guests as long as the update is limited to selections.
2342async fn update_buffer(
2343    request: proto::UpdateBuffer,
2344    response: Response<proto::UpdateBuffer>,
2345    session: MessageContext,
2346) -> Result<()> {
2347    let project_id = ProjectId::from_proto(request.project_id);
2348    let mut capability = Capability::ReadOnly;
2349
2350    for op in request.operations.iter() {
2351        match op.variant {
2352            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2353            Some(_) => capability = Capability::ReadWrite,
2354        }
2355    }
2356
2357    let host = {
2358        let guard = session
2359            .db()
2360            .await
2361            .connections_for_buffer_update(project_id, session.connection_id, capability)
2362            .await?;
2363
2364        let (host, guests) = &*guard;
2365
2366        broadcast(
2367            Some(session.connection_id),
2368            guests.clone(),
2369            |connection_id| {
2370                session
2371                    .peer
2372                    .forward_send(session.connection_id, connection_id, request.clone())
2373            },
2374        );
2375
2376        *host
2377    };
2378
2379    if host != session.connection_id {
2380        session.forward_request(host, request.clone()).await?;
2381    }
2382
2383    response.send(proto::Ack {})?;
2384    Ok(())
2385}
2386
2387async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2388    let project_id = ProjectId::from_proto(message.project_id);
2389
2390    let operation = message.operation.as_ref().context("invalid operation")?;
2391    let capability = match operation.variant.as_ref() {
2392        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2393            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2394                match buffer_op.variant {
2395                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2396                        Capability::ReadOnly
2397                    }
2398                    _ => Capability::ReadWrite,
2399                }
2400            } else {
2401                Capability::ReadWrite
2402            }
2403        }
2404        Some(_) => Capability::ReadWrite,
2405        None => Capability::ReadOnly,
2406    };
2407
2408    let guard = session
2409        .db()
2410        .await
2411        .connections_for_buffer_update(project_id, session.connection_id, capability)
2412        .await?;
2413
2414    let (host, guests) = &*guard;
2415
2416    broadcast(
2417        Some(session.connection_id),
2418        guests.iter().chain([host]).copied(),
2419        |connection_id| {
2420            session
2421                .peer
2422                .forward_send(session.connection_id, connection_id, message.clone())
2423        },
2424    );
2425
2426    Ok(())
2427}
2428
2429/// Notify other participants that a project has been updated.
2430async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2431    request: T,
2432    session: MessageContext,
2433) -> Result<()> {
2434    let project_id = ProjectId::from_proto(request.remote_entity_id());
2435    let project_connection_ids = session
2436        .db()
2437        .await
2438        .project_connection_ids(project_id, session.connection_id, false)
2439        .await?;
2440
2441    broadcast(
2442        Some(session.connection_id),
2443        project_connection_ids.iter().copied(),
2444        |connection_id| {
2445            session
2446                .peer
2447                .forward_send(session.connection_id, connection_id, request.clone())
2448        },
2449    );
2450    Ok(())
2451}
2452
2453/// Start following another user in a call.
2454async fn follow(
2455    request: proto::Follow,
2456    response: Response<proto::Follow>,
2457    session: MessageContext,
2458) -> Result<()> {
2459    let room_id = RoomId::from_proto(request.room_id);
2460    let project_id = request.project_id.map(ProjectId::from_proto);
2461    let leader_id = request.leader_id.context("invalid leader id")?.into();
2462    let follower_id = session.connection_id;
2463
2464    session
2465        .db()
2466        .await
2467        .check_room_participants(room_id, leader_id, session.connection_id)
2468        .await?;
2469
2470    let response_payload = session.forward_request(leader_id, request).await?;
2471    response.send(response_payload)?;
2472
2473    if let Some(project_id) = project_id {
2474        let room = session
2475            .db()
2476            .await
2477            .follow(room_id, project_id, leader_id, follower_id)
2478            .await?;
2479        room_updated(&room, &session.peer);
2480    }
2481
2482    Ok(())
2483}
2484
2485/// Stop following another user in a call.
2486async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2487    let room_id = RoomId::from_proto(request.room_id);
2488    let project_id = request.project_id.map(ProjectId::from_proto);
2489    let leader_id = request.leader_id.context("invalid leader id")?.into();
2490    let follower_id = session.connection_id;
2491
2492    session
2493        .db()
2494        .await
2495        .check_room_participants(room_id, leader_id, session.connection_id)
2496        .await?;
2497
2498    session
2499        .peer
2500        .forward_send(session.connection_id, leader_id, request)?;
2501
2502    if let Some(project_id) = project_id {
2503        let room = session
2504            .db()
2505            .await
2506            .unfollow(room_id, project_id, leader_id, follower_id)
2507            .await?;
2508        room_updated(&room, &session.peer);
2509    }
2510
2511    Ok(())
2512}
2513
2514/// Notify everyone following you of your current location.
2515async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2516    let room_id = RoomId::from_proto(request.room_id);
2517    let database = session.db.lock().await;
2518
2519    let connection_ids = if let Some(project_id) = request.project_id {
2520        let project_id = ProjectId::from_proto(project_id);
2521        database
2522            .project_connection_ids(project_id, session.connection_id, true)
2523            .await?
2524    } else {
2525        database
2526            .room_connection_ids(room_id, session.connection_id)
2527            .await?
2528    };
2529
2530    // For now, don't send view update messages back to that view's current leader.
2531    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2532        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2533        _ => None,
2534    });
2535
2536    for connection_id in connection_ids.iter().cloned() {
2537        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2538            session
2539                .peer
2540                .forward_send(session.connection_id, connection_id, request.clone())?;
2541        }
2542    }
2543    Ok(())
2544}
2545
2546/// Get public data about users.
2547async fn get_users(
2548    request: proto::GetUsers,
2549    response: Response<proto::GetUsers>,
2550    session: MessageContext,
2551) -> Result<()> {
2552    let user_ids = request
2553        .user_ids
2554        .into_iter()
2555        .map(UserId::from_proto)
2556        .collect();
2557    let users = session
2558        .db()
2559        .await
2560        .get_users_by_ids(user_ids)
2561        .await?
2562        .into_iter()
2563        .map(|user| proto::User {
2564            id: user.id.to_proto(),
2565            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2566            github_login: user.github_login,
2567            name: user.name,
2568        })
2569        .collect();
2570    response.send(proto::UsersResponse { users })?;
2571    Ok(())
2572}
2573
2574/// Search for users (to invite) buy Github login
2575async fn fuzzy_search_users(
2576    request: proto::FuzzySearchUsers,
2577    response: Response<proto::FuzzySearchUsers>,
2578    session: MessageContext,
2579) -> Result<()> {
2580    let query = request.query;
2581    let users = match query.len() {
2582        0 => vec![],
2583        1 | 2 => session
2584            .db()
2585            .await
2586            .get_user_by_github_login(&query)
2587            .await?
2588            .into_iter()
2589            .collect(),
2590        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2591    };
2592    let users = users
2593        .into_iter()
2594        .filter(|user| user.id != session.user_id())
2595        .map(|user| proto::User {
2596            id: user.id.to_proto(),
2597            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2598            github_login: user.github_login,
2599            name: user.name,
2600        })
2601        .collect();
2602    response.send(proto::UsersResponse { users })?;
2603    Ok(())
2604}
2605
2606/// Send a contact request to another user.
2607async fn request_contact(
2608    request: proto::RequestContact,
2609    response: Response<proto::RequestContact>,
2610    session: MessageContext,
2611) -> Result<()> {
2612    let requester_id = session.user_id();
2613    let responder_id = UserId::from_proto(request.responder_id);
2614    if requester_id == responder_id {
2615        return Err(anyhow!("cannot add yourself as a contact"))?;
2616    }
2617
2618    let notifications = session
2619        .db()
2620        .await
2621        .send_contact_request(requester_id, responder_id)
2622        .await?;
2623
2624    // Update outgoing contact requests of requester
2625    let mut update = proto::UpdateContacts::default();
2626    update.outgoing_requests.push(responder_id.to_proto());
2627    for connection_id in session
2628        .connection_pool()
2629        .await
2630        .user_connection_ids(requester_id)
2631    {
2632        session.peer.send(connection_id, update.clone())?;
2633    }
2634
2635    // Update incoming contact requests of responder
2636    let mut update = proto::UpdateContacts::default();
2637    update
2638        .incoming_requests
2639        .push(proto::IncomingContactRequest {
2640            requester_id: requester_id.to_proto(),
2641        });
2642    let connection_pool = session.connection_pool().await;
2643    for connection_id in connection_pool.user_connection_ids(responder_id) {
2644        session.peer.send(connection_id, update.clone())?;
2645    }
2646
2647    send_notifications(&connection_pool, &session.peer, notifications);
2648
2649    response.send(proto::Ack {})?;
2650    Ok(())
2651}
2652
2653/// Accept or decline a contact request
2654async fn respond_to_contact_request(
2655    request: proto::RespondToContactRequest,
2656    response: Response<proto::RespondToContactRequest>,
2657    session: MessageContext,
2658) -> Result<()> {
2659    let responder_id = session.user_id();
2660    let requester_id = UserId::from_proto(request.requester_id);
2661    let db = session.db().await;
2662    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2663        db.dismiss_contact_notification(responder_id, requester_id)
2664            .await?;
2665    } else {
2666        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2667
2668        let notifications = db
2669            .respond_to_contact_request(responder_id, requester_id, accept)
2670            .await?;
2671        let requester_busy = db.is_user_busy(requester_id).await?;
2672        let responder_busy = db.is_user_busy(responder_id).await?;
2673
2674        let pool = session.connection_pool().await;
2675        // Update responder with new contact
2676        let mut update = proto::UpdateContacts::default();
2677        if accept {
2678            update
2679                .contacts
2680                .push(contact_for_user(requester_id, requester_busy, &pool));
2681        }
2682        update
2683            .remove_incoming_requests
2684            .push(requester_id.to_proto());
2685        for connection_id in pool.user_connection_ids(responder_id) {
2686            session.peer.send(connection_id, update.clone())?;
2687        }
2688
2689        // Update requester with new contact
2690        let mut update = proto::UpdateContacts::default();
2691        if accept {
2692            update
2693                .contacts
2694                .push(contact_for_user(responder_id, responder_busy, &pool));
2695        }
2696        update
2697            .remove_outgoing_requests
2698            .push(responder_id.to_proto());
2699
2700        for connection_id in pool.user_connection_ids(requester_id) {
2701            session.peer.send(connection_id, update.clone())?;
2702        }
2703
2704        send_notifications(&pool, &session.peer, notifications);
2705    }
2706
2707    response.send(proto::Ack {})?;
2708    Ok(())
2709}
2710
2711/// Remove a contact.
2712async fn remove_contact(
2713    request: proto::RemoveContact,
2714    response: Response<proto::RemoveContact>,
2715    session: MessageContext,
2716) -> Result<()> {
2717    let requester_id = session.user_id();
2718    let responder_id = UserId::from_proto(request.user_id);
2719    let db = session.db().await;
2720    let (contact_accepted, deleted_notification_id) =
2721        db.remove_contact(requester_id, responder_id).await?;
2722
2723    let pool = session.connection_pool().await;
2724    // Update outgoing contact requests of requester
2725    let mut update = proto::UpdateContacts::default();
2726    if contact_accepted {
2727        update.remove_contacts.push(responder_id.to_proto());
2728    } else {
2729        update
2730            .remove_outgoing_requests
2731            .push(responder_id.to_proto());
2732    }
2733    for connection_id in pool.user_connection_ids(requester_id) {
2734        session.peer.send(connection_id, update.clone())?;
2735    }
2736
2737    // Update incoming contact requests of responder
2738    let mut update = proto::UpdateContacts::default();
2739    if contact_accepted {
2740        update.remove_contacts.push(requester_id.to_proto());
2741    } else {
2742        update
2743            .remove_incoming_requests
2744            .push(requester_id.to_proto());
2745    }
2746    for connection_id in pool.user_connection_ids(responder_id) {
2747        session.peer.send(connection_id, update.clone())?;
2748        if let Some(notification_id) = deleted_notification_id {
2749            session.peer.send(
2750                connection_id,
2751                proto::DeleteNotification {
2752                    notification_id: notification_id.to_proto(),
2753                },
2754            )?;
2755        }
2756    }
2757
2758    response.send(proto::Ack {})?;
2759    Ok(())
2760}
2761
2762fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
2763    version.0.minor < 139
2764}
2765
2766async fn subscribe_to_channels(
2767    _: proto::SubscribeToChannels,
2768    session: MessageContext,
2769) -> Result<()> {
2770    subscribe_user_to_channels(session.user_id(), &session).await?;
2771    Ok(())
2772}
2773
2774async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2775    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2776    let mut pool = session.connection_pool().await;
2777    for membership in &channels_for_user.channel_memberships {
2778        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2779    }
2780    session.peer.send(
2781        session.connection_id,
2782        build_update_user_channels(&channels_for_user),
2783    )?;
2784    session.peer.send(
2785        session.connection_id,
2786        build_channels_update(channels_for_user),
2787    )?;
2788    Ok(())
2789}
2790
2791/// Creates a new channel.
2792async fn create_channel(
2793    request: proto::CreateChannel,
2794    response: Response<proto::CreateChannel>,
2795    session: MessageContext,
2796) -> Result<()> {
2797    let db = session.db().await;
2798
2799    let parent_id = request.parent_id.map(ChannelId::from_proto);
2800    let (channel, membership) = db
2801        .create_channel(&request.name, parent_id, session.user_id())
2802        .await?;
2803
2804    let root_id = channel.root_id();
2805    let channel = Channel::from_model(channel);
2806
2807    response.send(proto::CreateChannelResponse {
2808        channel: Some(channel.to_proto()),
2809        parent_id: request.parent_id,
2810    })?;
2811
2812    let mut connection_pool = session.connection_pool().await;
2813    if let Some(membership) = membership {
2814        connection_pool.subscribe_to_channel(
2815            membership.user_id,
2816            membership.channel_id,
2817            membership.role,
2818        );
2819        let update = proto::UpdateUserChannels {
2820            channel_memberships: vec![proto::ChannelMembership {
2821                channel_id: membership.channel_id.to_proto(),
2822                role: membership.role.into(),
2823            }],
2824            ..Default::default()
2825        };
2826        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2827            session.peer.send(connection_id, update.clone())?;
2828        }
2829    }
2830
2831    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2832        if !role.can_see_channel(channel.visibility) {
2833            continue;
2834        }
2835
2836        let update = proto::UpdateChannels {
2837            channels: vec![channel.to_proto()],
2838            ..Default::default()
2839        };
2840        session.peer.send(connection_id, update.clone())?;
2841    }
2842
2843    Ok(())
2844}
2845
2846/// Delete a channel
2847async fn delete_channel(
2848    request: proto::DeleteChannel,
2849    response: Response<proto::DeleteChannel>,
2850    session: MessageContext,
2851) -> Result<()> {
2852    let db = session.db().await;
2853
2854    let channel_id = request.channel_id;
2855    let (root_channel, removed_channels) = db
2856        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2857        .await?;
2858    response.send(proto::Ack {})?;
2859
2860    // Notify members of removed channels
2861    let mut update = proto::UpdateChannels::default();
2862    update
2863        .delete_channels
2864        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2865
2866    let connection_pool = session.connection_pool().await;
2867    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2868        session.peer.send(connection_id, update.clone())?;
2869    }
2870
2871    Ok(())
2872}
2873
2874/// Invite someone to join a channel.
2875async fn invite_channel_member(
2876    request: proto::InviteChannelMember,
2877    response: Response<proto::InviteChannelMember>,
2878    session: MessageContext,
2879) -> Result<()> {
2880    let db = session.db().await;
2881    let channel_id = ChannelId::from_proto(request.channel_id);
2882    let invitee_id = UserId::from_proto(request.user_id);
2883    let InviteMemberResult {
2884        channel,
2885        notifications,
2886    } = db
2887        .invite_channel_member(
2888            channel_id,
2889            invitee_id,
2890            session.user_id(),
2891            request.role().into(),
2892        )
2893        .await?;
2894
2895    let update = proto::UpdateChannels {
2896        channel_invitations: vec![channel.to_proto()],
2897        ..Default::default()
2898    };
2899
2900    let connection_pool = session.connection_pool().await;
2901    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2902        session.peer.send(connection_id, update.clone())?;
2903    }
2904
2905    send_notifications(&connection_pool, &session.peer, notifications);
2906
2907    response.send(proto::Ack {})?;
2908    Ok(())
2909}
2910
2911/// remove someone from a channel
2912async fn remove_channel_member(
2913    request: proto::RemoveChannelMember,
2914    response: Response<proto::RemoveChannelMember>,
2915    session: MessageContext,
2916) -> Result<()> {
2917    let db = session.db().await;
2918    let channel_id = ChannelId::from_proto(request.channel_id);
2919    let member_id = UserId::from_proto(request.user_id);
2920
2921    let RemoveChannelMemberResult {
2922        membership_update,
2923        notification_id,
2924    } = db
2925        .remove_channel_member(channel_id, member_id, session.user_id())
2926        .await?;
2927
2928    let mut connection_pool = session.connection_pool().await;
2929    notify_membership_updated(
2930        &mut connection_pool,
2931        membership_update,
2932        member_id,
2933        &session.peer,
2934    );
2935    for connection_id in connection_pool.user_connection_ids(member_id) {
2936        if let Some(notification_id) = notification_id {
2937            session
2938                .peer
2939                .send(
2940                    connection_id,
2941                    proto::DeleteNotification {
2942                        notification_id: notification_id.to_proto(),
2943                    },
2944                )
2945                .trace_err();
2946        }
2947    }
2948
2949    response.send(proto::Ack {})?;
2950    Ok(())
2951}
2952
2953/// Toggle the channel between public and private.
2954/// Care is taken to maintain the invariant that public channels only descend from public channels,
2955/// (though members-only channels can appear at any point in the hierarchy).
2956async fn set_channel_visibility(
2957    request: proto::SetChannelVisibility,
2958    response: Response<proto::SetChannelVisibility>,
2959    session: MessageContext,
2960) -> Result<()> {
2961    let db = session.db().await;
2962    let channel_id = ChannelId::from_proto(request.channel_id);
2963    let visibility = request.visibility().into();
2964
2965    let channel_model = db
2966        .set_channel_visibility(channel_id, visibility, session.user_id())
2967        .await?;
2968    let root_id = channel_model.root_id();
2969    let channel = Channel::from_model(channel_model);
2970
2971    let mut connection_pool = session.connection_pool().await;
2972    for (user_id, role) in connection_pool
2973        .channel_user_ids(root_id)
2974        .collect::<Vec<_>>()
2975        .into_iter()
2976    {
2977        let update = if role.can_see_channel(channel.visibility) {
2978            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2979            proto::UpdateChannels {
2980                channels: vec![channel.to_proto()],
2981                ..Default::default()
2982            }
2983        } else {
2984            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2985            proto::UpdateChannels {
2986                delete_channels: vec![channel.id.to_proto()],
2987                ..Default::default()
2988            }
2989        };
2990
2991        for connection_id in connection_pool.user_connection_ids(user_id) {
2992            session.peer.send(connection_id, update.clone())?;
2993        }
2994    }
2995
2996    response.send(proto::Ack {})?;
2997    Ok(())
2998}
2999
3000/// Alter the role for a user in the channel.
3001async fn set_channel_member_role(
3002    request: proto::SetChannelMemberRole,
3003    response: Response<proto::SetChannelMemberRole>,
3004    session: MessageContext,
3005) -> Result<()> {
3006    let db = session.db().await;
3007    let channel_id = ChannelId::from_proto(request.channel_id);
3008    let member_id = UserId::from_proto(request.user_id);
3009    let result = db
3010        .set_channel_member_role(
3011            channel_id,
3012            session.user_id(),
3013            member_id,
3014            request.role().into(),
3015        )
3016        .await?;
3017
3018    match result {
3019        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3020            let mut connection_pool = session.connection_pool().await;
3021            notify_membership_updated(
3022                &mut connection_pool,
3023                membership_update,
3024                member_id,
3025                &session.peer,
3026            )
3027        }
3028        db::SetMemberRoleResult::InviteUpdated(channel) => {
3029            let update = proto::UpdateChannels {
3030                channel_invitations: vec![channel.to_proto()],
3031                ..Default::default()
3032            };
3033
3034            for connection_id in session
3035                .connection_pool()
3036                .await
3037                .user_connection_ids(member_id)
3038            {
3039                session.peer.send(connection_id, update.clone())?;
3040            }
3041        }
3042    }
3043
3044    response.send(proto::Ack {})?;
3045    Ok(())
3046}
3047
3048/// Change the name of a channel
3049async fn rename_channel(
3050    request: proto::RenameChannel,
3051    response: Response<proto::RenameChannel>,
3052    session: MessageContext,
3053) -> Result<()> {
3054    let db = session.db().await;
3055    let channel_id = ChannelId::from_proto(request.channel_id);
3056    let channel_model = db
3057        .rename_channel(channel_id, session.user_id(), &request.name)
3058        .await?;
3059    let root_id = channel_model.root_id();
3060    let channel = Channel::from_model(channel_model);
3061
3062    response.send(proto::RenameChannelResponse {
3063        channel: Some(channel.to_proto()),
3064    })?;
3065
3066    let connection_pool = session.connection_pool().await;
3067    let update = proto::UpdateChannels {
3068        channels: vec![channel.to_proto()],
3069        ..Default::default()
3070    };
3071    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3072        if role.can_see_channel(channel.visibility) {
3073            session.peer.send(connection_id, update.clone())?;
3074        }
3075    }
3076
3077    Ok(())
3078}
3079
3080/// Move a channel to a new parent.
3081async fn move_channel(
3082    request: proto::MoveChannel,
3083    response: Response<proto::MoveChannel>,
3084    session: MessageContext,
3085) -> Result<()> {
3086    let channel_id = ChannelId::from_proto(request.channel_id);
3087    let to = ChannelId::from_proto(request.to);
3088
3089    let (root_id, channels) = session
3090        .db()
3091        .await
3092        .move_channel(channel_id, to, session.user_id())
3093        .await?;
3094
3095    let connection_pool = session.connection_pool().await;
3096    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3097        let channels = channels
3098            .iter()
3099            .filter_map(|channel| {
3100                if role.can_see_channel(channel.visibility) {
3101                    Some(channel.to_proto())
3102                } else {
3103                    None
3104                }
3105            })
3106            .collect::<Vec<_>>();
3107        if channels.is_empty() {
3108            continue;
3109        }
3110
3111        let update = proto::UpdateChannels {
3112            channels,
3113            ..Default::default()
3114        };
3115
3116        session.peer.send(connection_id, update.clone())?;
3117    }
3118
3119    response.send(Ack {})?;
3120    Ok(())
3121}
3122
3123async fn reorder_channel(
3124    request: proto::ReorderChannel,
3125    response: Response<proto::ReorderChannel>,
3126    session: MessageContext,
3127) -> Result<()> {
3128    let channel_id = ChannelId::from_proto(request.channel_id);
3129    let direction = request.direction();
3130
3131    let updated_channels = session
3132        .db()
3133        .await
3134        .reorder_channel(channel_id, direction, session.user_id())
3135        .await?;
3136
3137    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3138        let connection_pool = session.connection_pool().await;
3139        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3140            let channels = updated_channels
3141                .iter()
3142                .filter_map(|channel| {
3143                    if role.can_see_channel(channel.visibility) {
3144                        Some(channel.to_proto())
3145                    } else {
3146                        None
3147                    }
3148                })
3149                .collect::<Vec<_>>();
3150
3151            if channels.is_empty() {
3152                continue;
3153            }
3154
3155            let update = proto::UpdateChannels {
3156                channels,
3157                ..Default::default()
3158            };
3159
3160            session.peer.send(connection_id, update.clone())?;
3161        }
3162    }
3163
3164    response.send(Ack {})?;
3165    Ok(())
3166}
3167
3168/// Get the list of channel members
3169async fn get_channel_members(
3170    request: proto::GetChannelMembers,
3171    response: Response<proto::GetChannelMembers>,
3172    session: MessageContext,
3173) -> Result<()> {
3174    let db = session.db().await;
3175    let channel_id = ChannelId::from_proto(request.channel_id);
3176    let limit = if request.limit == 0 {
3177        u16::MAX as u64
3178    } else {
3179        request.limit
3180    };
3181    let (members, users) = db
3182        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3183        .await?;
3184    response.send(proto::GetChannelMembersResponse { members, users })?;
3185    Ok(())
3186}
3187
3188/// Accept or decline a channel invitation.
3189async fn respond_to_channel_invite(
3190    request: proto::RespondToChannelInvite,
3191    response: Response<proto::RespondToChannelInvite>,
3192    session: MessageContext,
3193) -> Result<()> {
3194    let db = session.db().await;
3195    let channel_id = ChannelId::from_proto(request.channel_id);
3196    let RespondToChannelInvite {
3197        membership_update,
3198        notifications,
3199    } = db
3200        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3201        .await?;
3202
3203    let mut connection_pool = session.connection_pool().await;
3204    if let Some(membership_update) = membership_update {
3205        notify_membership_updated(
3206            &mut connection_pool,
3207            membership_update,
3208            session.user_id(),
3209            &session.peer,
3210        );
3211    } else {
3212        let update = proto::UpdateChannels {
3213            remove_channel_invitations: vec![channel_id.to_proto()],
3214            ..Default::default()
3215        };
3216
3217        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3218            session.peer.send(connection_id, update.clone())?;
3219        }
3220    };
3221
3222    send_notifications(&connection_pool, &session.peer, notifications);
3223
3224    response.send(proto::Ack {})?;
3225
3226    Ok(())
3227}
3228
3229/// Join the channels' room
3230async fn join_channel(
3231    request: proto::JoinChannel,
3232    response: Response<proto::JoinChannel>,
3233    session: MessageContext,
3234) -> Result<()> {
3235    let channel_id = ChannelId::from_proto(request.channel_id);
3236    join_channel_internal(channel_id, Box::new(response), session).await
3237}
3238
3239trait JoinChannelInternalResponse {
3240    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3241}
3242impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3243    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3244        Response::<proto::JoinChannel>::send(self, result)
3245    }
3246}
3247impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3248    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3249        Response::<proto::JoinRoom>::send(self, result)
3250    }
3251}
3252
/// Join the room of channel `channel_id` on behalf of the session's user, replying
/// through `response`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that room
///    first. The database guard is released while doing so.
/// 2. Join the channel in the database. This yields the joined room, an optional
///    membership update, and the user's channel role.
/// 3. If a LiveKit client is configured, build connection info. `Guest` roles get a
///    guest token and `can_publish: false`; every other role gets a room token and
///    `can_publish: true`. A token error goes through `trace_err` and yields no
///    connection info instead of failing the join.
/// 4. Send the `JoinRoomResponse`, propagate any membership update, and call
///    `room_updated`.
/// 5. Once the inner block has dropped the database guard, call `channel_updated` for
///    the joined channel and refresh the user's contacts.
///
/// Errors from the database, from leaving a stale room, from sending the response, or
/// from `update_user_contacts` are propagated. An error is also returned if the joined
/// room carries no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    // Everything that needs the database guard happens inside this block; the guard
    // and the connection-pool guard are dropped when it ends.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the guard while leaving the stale room, then re-acquire it.
            // NOTE(review): presumably `leave_room_for_session` locks the db itself and
            // would deadlock if the guard were still held — confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or when minting a token fails
        // (the `?` after `trace_err()` exits the closure with `None`).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests receive a guest token and may not publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the joining client before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The db guard is released at this point; the connection pool is locked again just
    // for this call.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3346
3347/// Start editing the channel notes
3348async fn join_channel_buffer(
3349    request: proto::JoinChannelBuffer,
3350    response: Response<proto::JoinChannelBuffer>,
3351    session: MessageContext,
3352) -> Result<()> {
3353    let db = session.db().await;
3354    let channel_id = ChannelId::from_proto(request.channel_id);
3355
3356    let open_response = db
3357        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3358        .await?;
3359
3360    let collaborators = open_response.collaborators.clone();
3361    response.send(open_response)?;
3362
3363    let update = UpdateChannelBufferCollaborators {
3364        channel_id: channel_id.to_proto(),
3365        collaborators: collaborators.clone(),
3366    };
3367    channel_buffer_updated(
3368        session.connection_id,
3369        collaborators
3370            .iter()
3371            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3372        &update,
3373        &session.peer,
3374    );
3375
3376    Ok(())
3377}
3378
3379/// Edit the channel notes
3380async fn update_channel_buffer(
3381    request: proto::UpdateChannelBuffer,
3382    session: MessageContext,
3383) -> Result<()> {
3384    let db = session.db().await;
3385    let channel_id = ChannelId::from_proto(request.channel_id);
3386
3387    let (collaborators, epoch, version) = db
3388        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3389        .await?;
3390
3391    channel_buffer_updated(
3392        session.connection_id,
3393        collaborators.clone(),
3394        &proto::UpdateChannelBuffer {
3395            channel_id: channel_id.to_proto(),
3396            operations: request.operations,
3397        },
3398        &session.peer,
3399    );
3400
3401    let pool = &*session.connection_pool().await;
3402
3403    let non_collaborators =
3404        pool.channel_connection_ids(channel_id)
3405            .filter_map(|(connection_id, _)| {
3406                if collaborators.contains(&connection_id) {
3407                    None
3408                } else {
3409                    Some(connection_id)
3410                }
3411            });
3412
3413    broadcast(None, non_collaborators, |peer_id| {
3414        session.peer.send(
3415            peer_id,
3416            proto::UpdateChannels {
3417                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3418                    channel_id: channel_id.to_proto(),
3419                    epoch: epoch as u64,
3420                    version: version.clone(),
3421                }],
3422                ..Default::default()
3423            },
3424        )
3425    });
3426
3427    Ok(())
3428}
3429
3430/// Rejoin the channel notes after a connection blip
3431async fn rejoin_channel_buffers(
3432    request: proto::RejoinChannelBuffers,
3433    response: Response<proto::RejoinChannelBuffers>,
3434    session: MessageContext,
3435) -> Result<()> {
3436    let db = session.db().await;
3437    let buffers = db
3438        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3439        .await?;
3440
3441    for rejoined_buffer in &buffers {
3442        let collaborators_to_notify = rejoined_buffer
3443            .buffer
3444            .collaborators
3445            .iter()
3446            .filter_map(|c| Some(c.peer_id?.into()));
3447        channel_buffer_updated(
3448            session.connection_id,
3449            collaborators_to_notify,
3450            &proto::UpdateChannelBufferCollaborators {
3451                channel_id: rejoined_buffer.buffer.channel_id,
3452                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3453            },
3454            &session.peer,
3455        );
3456    }
3457
3458    response.send(proto::RejoinChannelBuffersResponse {
3459        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3460    })?;
3461
3462    Ok(())
3463}
3464
3465/// Stop editing the channel notes
3466async fn leave_channel_buffer(
3467    request: proto::LeaveChannelBuffer,
3468    response: Response<proto::LeaveChannelBuffer>,
3469    session: MessageContext,
3470) -> Result<()> {
3471    let db = session.db().await;
3472    let channel_id = ChannelId::from_proto(request.channel_id);
3473
3474    let left_buffer = db
3475        .leave_channel_buffer(channel_id, session.connection_id)
3476        .await?;
3477
3478    response.send(Ack {})?;
3479
3480    channel_buffer_updated(
3481        session.connection_id,
3482        left_buffer.connections,
3483        &proto::UpdateChannelBufferCollaborators {
3484            channel_id: channel_id.to_proto(),
3485            collaborators: left_buffer.collaborators,
3486        },
3487        &session.peer,
3488    );
3489
3490    Ok(())
3491}
3492
3493fn channel_buffer_updated<T: EnvelopedMessage>(
3494    sender_id: ConnectionId,
3495    collaborators: impl IntoIterator<Item = ConnectionId>,
3496    message: &T,
3497    peer: &Peer,
3498) {
3499    broadcast(Some(sender_id), collaborators, |peer_id| {
3500        peer.send(peer_id, message.clone())
3501    });
3502}
3503
3504fn send_notifications(
3505    connection_pool: &ConnectionPool,
3506    peer: &Peer,
3507    notifications: db::NotificationBatch,
3508) {
3509    for (user_id, notification) in notifications {
3510        for connection_id in connection_pool.user_connection_ids(user_id) {
3511            if let Err(error) = peer.send(
3512                connection_id,
3513                proto::AddNotification {
3514                    notification: Some(notification.clone()),
3515                },
3516            ) {
3517                tracing::error!(
3518                    "failed to send notification to {:?} {}",
3519                    connection_id,
3520                    error
3521                );
3522            }
3523        }
3524    }
3525}
3526
3527/// Send a message to the channel
3528async fn send_channel_message(
3529    _request: proto::SendChannelMessage,
3530    _response: Response<proto::SendChannelMessage>,
3531    _session: MessageContext,
3532) -> Result<()> {
3533    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3534}
3535
3536/// Delete a channel message
3537async fn remove_channel_message(
3538    _request: proto::RemoveChannelMessage,
3539    _response: Response<proto::RemoveChannelMessage>,
3540    _session: MessageContext,
3541) -> Result<()> {
3542    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3543}
3544
3545async fn update_channel_message(
3546    _request: proto::UpdateChannelMessage,
3547    _response: Response<proto::UpdateChannelMessage>,
3548    _session: MessageContext,
3549) -> Result<()> {
3550    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3551}
3552
3553/// Mark a channel message as read
3554async fn acknowledge_channel_message(
3555    _request: proto::AckChannelMessage,
3556    _session: MessageContext,
3557) -> Result<()> {
3558    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3559}
3560
3561/// Mark a buffer version as synced
3562async fn acknowledge_buffer_version(
3563    request: proto::AckBufferOperation,
3564    session: MessageContext,
3565) -> Result<()> {
3566    let buffer_id = BufferId::from_proto(request.buffer_id);
3567    session
3568        .db()
3569        .await
3570        .observe_buffer_version(
3571            buffer_id,
3572            session.user_id(),
3573            request.epoch as i32,
3574            &request.version,
3575        )
3576        .await?;
3577    Ok(())
3578}
3579
3580/// Start receiving chat updates for a channel
3581async fn join_channel_chat(
3582    _request: proto::JoinChannelChat,
3583    _response: Response<proto::JoinChannelChat>,
3584    _session: MessageContext,
3585) -> Result<()> {
3586    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3587}
3588
3589/// Stop receiving chat updates for a channel
3590async fn leave_channel_chat(
3591    _request: proto::LeaveChannelChat,
3592    _session: MessageContext,
3593) -> Result<()> {
3594    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3595}
3596
3597/// Retrieve the chat history for a channel
3598async fn get_channel_messages(
3599    _request: proto::GetChannelMessages,
3600    _response: Response<proto::GetChannelMessages>,
3601    _session: MessageContext,
3602) -> Result<()> {
3603    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3604}
3605
3606/// Retrieve specific chat messages
3607async fn get_channel_messages_by_id(
3608    _request: proto::GetChannelMessagesById,
3609    _response: Response<proto::GetChannelMessagesById>,
3610    _session: MessageContext,
3611) -> Result<()> {
3612    Err(anyhow!("chat has been removed in the latest version of Zed").into())
3613}
3614
3615/// Retrieve the current users notifications
3616async fn get_notifications(
3617    request: proto::GetNotifications,
3618    response: Response<proto::GetNotifications>,
3619    session: MessageContext,
3620) -> Result<()> {
3621    let notifications = session
3622        .db()
3623        .await
3624        .get_notifications(
3625            session.user_id(),
3626            NOTIFICATION_COUNT_PER_PAGE,
3627            request.before_id.map(db::NotificationId::from_proto),
3628        )
3629        .await?;
3630    response.send(proto::GetNotificationsResponse {
3631        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3632        notifications,
3633    })?;
3634    Ok(())
3635}
3636
3637/// Mark notifications as read
3638async fn mark_notification_as_read(
3639    request: proto::MarkNotificationRead,
3640    response: Response<proto::MarkNotificationRead>,
3641    session: MessageContext,
3642) -> Result<()> {
3643    let database = &session.db().await;
3644    let notifications = database
3645        .mark_notification_as_read_by_id(
3646            session.user_id(),
3647            NotificationId::from_proto(request.notification_id),
3648        )
3649        .await?;
3650    send_notifications(
3651        &*session.connection_pool().await,
3652        &session.peer,
3653        notifications,
3654    );
3655    response.send(proto::Ack {})?;
3656    Ok(())
3657}
3658
3659fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3660    let message = match message {
3661        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3662        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3663        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3664        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3665        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3666            code: frame.code.into(),
3667            reason: frame.reason.as_str().to_owned().into(),
3668        })),
3669        // We should never receive a frame while reading the message, according
3670        // to the `tungstenite` maintainers:
3671        //
3672        // > It cannot occur when you read messages from the WebSocket, but it
3673        // > can be used when you want to send the raw frames (e.g. you want to
3674        // > send the frames to the WebSocket without composing the full message first).
3675        // >
3676        // > — https://github.com/snapview/tungstenite-rs/issues/268
3677        TungsteniteMessage::Frame(_) => {
3678            bail!("received an unexpected frame while reading the message")
3679        }
3680    };
3681
3682    Ok(message)
3683}
3684
3685fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3686    match message {
3687        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3688        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3689        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3690        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3691        AxumMessage::Close(frame) => {
3692            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3693                code: frame.code.into(),
3694                reason: frame.reason.as_ref().into(),
3695            }))
3696        }
3697    }
3698}
3699
3700fn notify_membership_updated(
3701    connection_pool: &mut ConnectionPool,
3702    result: MembershipUpdated,
3703    user_id: UserId,
3704    peer: &Peer,
3705) {
3706    for membership in &result.new_channels.channel_memberships {
3707        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3708    }
3709    for channel_id in &result.removed_channels {
3710        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3711    }
3712
3713    let user_channels_update = proto::UpdateUserChannels {
3714        channel_memberships: result
3715            .new_channels
3716            .channel_memberships
3717            .iter()
3718            .map(|cm| proto::ChannelMembership {
3719                channel_id: cm.channel_id.to_proto(),
3720                role: cm.role.into(),
3721            })
3722            .collect(),
3723        ..Default::default()
3724    };
3725
3726    let mut update = build_channels_update(result.new_channels);
3727    update.delete_channels = result
3728        .removed_channels
3729        .into_iter()
3730        .map(|id| id.to_proto())
3731        .collect();
3732    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3733
3734    for connection_id in connection_pool.user_connection_ids(user_id) {
3735        peer.send(connection_id, user_channels_update.clone())
3736            .trace_err();
3737        peer.send(connection_id, update.clone()).trace_err();
3738    }
3739}
3740
3741fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3742    proto::UpdateUserChannels {
3743        channel_memberships: channels
3744            .channel_memberships
3745            .iter()
3746            .map(|m| proto::ChannelMembership {
3747                channel_id: m.channel_id.to_proto(),
3748                role: m.role.into(),
3749            })
3750            .collect(),
3751        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3752    }
3753}
3754
3755fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3756    let mut update = proto::UpdateChannels::default();
3757
3758    for channel in channels.channels {
3759        update.channels.push(channel.to_proto());
3760    }
3761
3762    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3763
3764    for (channel_id, participants) in channels.channel_participants {
3765        update
3766            .channel_participants
3767            .push(proto::ChannelParticipants {
3768                channel_id: channel_id.to_proto(),
3769                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3770            });
3771    }
3772
3773    for channel in channels.invited_channels {
3774        update.channel_invitations.push(channel.to_proto());
3775    }
3776
3777    update
3778}
3779
3780fn build_initial_contacts_update(
3781    contacts: Vec<db::Contact>,
3782    pool: &ConnectionPool,
3783) -> proto::UpdateContacts {
3784    let mut update = proto::UpdateContacts::default();
3785
3786    for contact in contacts {
3787        match contact {
3788            db::Contact::Accepted { user_id, busy } => {
3789                update.contacts.push(contact_for_user(user_id, busy, pool));
3790            }
3791            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3792            db::Contact::Incoming { user_id } => {
3793                update
3794                    .incoming_requests
3795                    .push(proto::IncomingContactRequest {
3796                        requester_id: user_id.to_proto(),
3797                    })
3798            }
3799        }
3800    }
3801
3802    update
3803}
3804
3805fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3806    proto::Contact {
3807        user_id: user_id.to_proto(),
3808        online: pool.is_user_online(user_id),
3809        busy,
3810    }
3811}
3812
3813fn room_updated(room: &proto::Room, peer: &Peer) {
3814    broadcast(
3815        None,
3816        room.participants
3817            .iter()
3818            .filter_map(|participant| Some(participant.peer_id?.into())),
3819        |peer_id| {
3820            peer.send(
3821                peer_id,
3822                proto::RoomUpdated {
3823                    room: Some(room.clone()),
3824                },
3825            )
3826        },
3827    );
3828}
3829
3830fn channel_updated(
3831    channel: &db::channel::Model,
3832    room: &proto::Room,
3833    peer: &Peer,
3834    pool: &ConnectionPool,
3835) {
3836    let participants = room
3837        .participants
3838        .iter()
3839        .map(|p| p.user_id)
3840        .collect::<Vec<_>>();
3841
3842    broadcast(
3843        None,
3844        pool.channel_connection_ids(channel.root_id())
3845            .filter_map(|(channel_id, role)| {
3846                role.can_see_channel(channel.visibility)
3847                    .then_some(channel_id)
3848            }),
3849        |peer_id| {
3850            peer.send(
3851                peer_id,
3852                proto::UpdateChannels {
3853                    channel_participants: vec![proto::ChannelParticipants {
3854                        channel_id: channel.id.to_proto(),
3855                        participant_user_ids: participants.clone(),
3856                    }],
3857                    ..Default::default()
3858                },
3859            )
3860        },
3861    );
3862}
3863
3864async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3865    let db = session.db().await;
3866
3867    let contacts = db.get_contacts(user_id).await?;
3868    let busy = db.is_user_busy(user_id).await?;
3869
3870    let pool = session.connection_pool().await;
3871    let updated_contact = contact_for_user(user_id, busy, &pool);
3872    for contact in contacts {
3873        if let db::Contact::Accepted {
3874            user_id: contact_user_id,
3875            ..
3876        } = contact
3877        {
3878            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3879                session
3880                    .peer
3881                    .send(
3882                        contact_conn_id,
3883                        proto::UpdateContacts {
3884                            contacts: vec![updated_contact.clone()],
3885                            remove_contacts: Default::default(),
3886                            incoming_requests: Default::default(),
3887                            remove_incoming_requests: Default::default(),
3888                            outgoing_requests: Default::default(),
3889                            remove_outgoing_requests: Default::default(),
3890                        },
3891                    )
3892                    .trace_err();
3893            }
3894        }
3895    }
3896    Ok(())
3897}
3898
/// Removes `connection_id` from whatever room it is in and notifies everyone
/// affected.
///
/// Does nothing and returns `Ok(())` if the database reports no room for the
/// connection. Otherwise it:
/// - notifies the collaborators of every project left along with the room;
/// - broadcasts the updated room and, if the room belongs to a channel, the
///   channel's participant list;
/// - sends `CallCanceled` to users whose pending calls into the room were
///   canceled;
/// - refreshes contact status for this session's user and those users;
/// - removes the session's user from the LiveKit room, deleting that room if
///   the database marked it as deleted.
///
/// NOTE(review): contacts and LiveKit cleanup use `session.user_id()`, which
/// assumes `connection_id` belongs to this session's user — confirm for all
/// callers.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned exactly once in the `if let` branch below; the `else` branch
    // returns early, so all of these are initialized afterwards.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // `livekit_room` is taken out before `room` is, so the `room` value
        // broadcast below carries an empty `livekit_room` field.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scope the pool guard so it is released before `update_user_contacts`,
    // which acquires the connection pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged, not propagated.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
3972
3973async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3974    let left_channel_buffers = session
3975        .db()
3976        .await
3977        .leave_channel_buffers(session.connection_id)
3978        .await?;
3979
3980    for left_buffer in left_channel_buffers {
3981        channel_buffer_updated(
3982            session.connection_id,
3983            left_buffer.connections,
3984            &proto::UpdateChannelBufferCollaborators {
3985                channel_id: left_buffer.channel_id.to_proto(),
3986                collaborators: left_buffer.collaborators,
3987            },
3988            &session.peer,
3989        );
3990    }
3991
3992    Ok(())
3993}
3994
3995fn project_left(project: &db::LeftProject, session: &Session) {
3996    for connection_id in &project.connection_ids {
3997        if project.should_unshare {
3998            session
3999                .peer
4000                .send(
4001                    *connection_id,
4002                    proto::UnshareProject {
4003                        project_id: project.id.to_proto(),
4004                    },
4005                )
4006                .trace_err();
4007        } else {
4008            session
4009                .peer
4010                .send(
4011                    *connection_id,
4012                    proto::RemoveProjectCollaborator {
4013                        project_id: project.id.to_proto(),
4014                        peer_id: Some(session.connection_id.into()),
4015                    },
4016                )
4017                .trace_err();
4018        }
4019    }
4020}
4021
4022async fn share_agent_thread(
4023    request: proto::ShareAgentThread,
4024    response: Response<proto::ShareAgentThread>,
4025    session: MessageContext,
4026) -> Result<()> {
4027    let user_id = session.user_id();
4028
4029    let share_id = SharedThreadId::from_proto(request.session_id.clone())
4030        .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4031
4032    session
4033        .db()
4034        .await
4035        .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
4036        .await?;
4037
4038    response.send(proto::Ack {})?;
4039
4040    Ok(())
4041}
4042
4043async fn get_shared_agent_thread(
4044    request: proto::GetSharedAgentThread,
4045    response: Response<proto::GetSharedAgentThread>,
4046    session: MessageContext,
4047) -> Result<()> {
4048    let share_id = SharedThreadId::from_proto(request.session_id)
4049        .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4050
4051    let result = session.db().await.get_shared_thread(share_id).await?;
4052
4053    match result {
4054        Some((thread, username)) => {
4055            response.send(proto::GetSharedAgentThreadResponse {
4056                title: thread.title,
4057                thread_data: thread.data,
4058                sharer_username: username,
4059                created_at: thread.created_at.and_utc().to_rfc3339(),
4060            })?;
4061        }
4062        None => {
4063            return Err(anyhow!("Shared thread not found").into());
4064        }
4065    }
4066
4067    Ok(())
4068}
4069
/// Extension for fallible values whose errors should be reported and then
/// discarded rather than propagated.
pub trait ResultExt {
    /// The success type produced when the value is not an error.
    type Ok;

    /// Converts `self` into an `Option`, discarding the error. The
    /// implementation for `Result` logs it at `error` level first.
    fn trace_err(self) -> Option<Self::Ok>;
}
4075
4076impl<T, E> ResultExt for Result<T, E>
4077where
4078    E: std::fmt::Debug,
4079{
4080    type Ok = T;
4081
4082    #[track_caller]
4083    fn trace_err(self) -> Option<T> {
4084        match self {
4085            Ok(value) => Some(value),
4086            Err(error) => {
4087                tracing::error!("{:?}", error);
4088                None
4089            }
4090        }
4091    }
4092}